def init_discriminator(self, args):
    # init D
    self.discriminator_model = FCDiscriminator(num_classes=2).cuda()
    self.interp = nn.Upsample(size=400, mode='bilinear')
    self.disc_criterion = SegmentationLosses(
        weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
    return
Example #2
def load_models(mode, device, args):
    """

    :param mode: "SS" or "Discriminator"
    :param args:
    :return:
    """

    if mode == "SS":
        if args.network == "segnet_small":
            from models.SegNet import SegNet_Small
            model = SegNet_Small(args.channels,
                                 args.classes,
                                 args.skip_type,
                                 BR_bool=args.BR,
                                 separable_conv=args.SC)
            model = model.to(device)
        else:
            raise ValueError("Unknown network {}!".format(args.network))

        summary(model, (args.channels, args.image_size, args.image_size), args)
    elif mode == "Discriminator":
        from models.discriminator import FCDiscriminator
        model = FCDiscriminator(num_classes=NUM_CLASSES)
    else:
        raise ValueError("Invalid mode {}!".format(mode))

    try:
        if args.checkpoint_SS:
            model.load_state_dict(torch.load(args.checkpoint_SS))
        if args.checkpoint_DNet:
            model.load_state_dict(torch.load(args.checkpoint_DNet))

    except Exception as e:
        print(e)
        sys.exit(0)

    return model
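
# A minimal usage sketch for load_models (hypothetical argparse values; the field names
# simply mirror the attributes accessed above):
#   args = argparse.Namespace(network="segnet_small", channels=3, classes=2,
#                             skip_type="add", BR=True, SC=False, image_size=256,
#                             checkpoint_SS=None, checkpoint_DNet=None)
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   seg_model = load_models("SS", device, args)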
Example #4
class CustomModel():
    def __init__(self, cfg, writer, logger, use_pseudo_label=False, modal_num=3, multimodal_merger=multimodal_merger):
        self.cfg = cfg
        self.writer = writer
        self.class_numbers = 19
        self.logger = logger
        cfg_model = cfg['model']
        self.cfg_model = cfg_model
        self.best_iou = -100
        self.iter = 0
        self.nets = []
        self.split_gpu = 0
        self.default_gpu = cfg['model']['default_gpu']
        self.PredNet_Dir = None
        self.valid_classes = cfg['training']['valid_classes']
        self.G_train = True
        self.cls_feature_weight = cfg['training']['cls_feature_weight']
        self.use_pseudo_label = use_pseudo_label
        self.modal_num = modal_num

        # cluster vectors & cuda initialization
        self.objective_vectors_group = torch.zeros(self.modal_num + 1, 19, 256).cuda()
        self.objective_vectors_num_group = torch.zeros(self.modal_num + 1, 19).cuda()
        self.objective_vectors_dis_group = torch.zeros(self.modal_num + 1, 19, 19).cuda()
        self.class_threshold_group = torch.full([self.modal_num + 1, 19], 0.6).cuda()

        self.disc_T = torch.FloatTensor([0.0]).cuda()

        #self.metrics = CustomMetrics(self.class_numbers)
        self.metrics = CustomMetrics(self.class_numbers, modal_num=self.modal_num, model=self)

        # multimodal / multi-branch merger
        self.multimodal_merger = multimodal_merger

        bn = cfg_model['bn']
        if bn == 'sync_bn':
            BatchNorm = SynchronizedBatchNorm2d
        elif bn == 'bn':
            BatchNorm = nn.BatchNorm2d
        elif bn == 'gn':
            BatchNorm = nn.GroupNorm
        else:
            raise NotImplementedError('batch norm choice {} is not implemented'.format(bn))

        # PredNet: a frozen, pretrained segmentation network (plus its discriminator
        # PredDnet) used only for inference on the target domain, to produce pseudo
        # labels and the easy/hard weighting in step().
        self.PredNet = DeepLab(
                num_classes=19,
                backbone=cfg_model['basenet']['version'],
                output_stride=16,
                bn=cfg_model['bn'],
                freeze_bn=True,
                modal_num=self.modal_num
                ).cuda()
        self.load_PredNet(cfg, writer, logger, dir=None, net=self.PredNet)
        self.PredNet_DP = self.init_device(self.PredNet, gpu_id=self.default_gpu, whether_DP=True)
        self.PredNet.eval()
        self.PredNet_num = 0

        self.PredDnet = FCDiscriminator(inplanes=19)
        self.load_PredDnet(cfg, writer, logger, dir=None, net=self.PredDnet)
        self.PredDnet_DP = self.init_device(self.PredDnet, gpu_id=self.default_gpu, whether_DP=True)
        self.PredDnet.eval()

        self.BaseNet = DeepLab(
                            num_classes=19,
                            backbone=cfg_model['basenet']['version'],
                            output_stride=16,
                            bn=cfg_model['bn'],
                            freeze_bn=True, 
                            modal_num=self.modal_num
                            )

        logger.info('the backbone is {}'.format(cfg_model['basenet']['version']))

        self.BaseNet_DP = self.init_device(self.BaseNet, gpu_id=self.default_gpu, whether_DP=True)
        self.nets.extend([self.BaseNet])
        self.nets_DP = [self.BaseNet_DP]

        # Discriminator
        self.SOURCE_LABEL = 0
        self.TARGET_LABEL = 1
        self.DNets = []
        self.DNets_DP = []
        for _ in range(self.modal_num+1):
            _net_d = FCDiscriminator(inplanes=19)
            self.DNets.append(_net_d)
            _net_d_DP = self.init_device(_net_d, gpu_id=self.default_gpu, whether_DP=True)
            self.DNets_DP.append(_net_d_DP)

        self.nets.extend(self.DNets)
        self.nets_DP.extend(self.DNets_DP)

        self.optimizers = []
        self.schedulers = []        

        optimizer_cls = torch.optim.SGD
        optimizer_params = {k:v for k, v in cfg['training']['optimizer'].items() 
                            if k != 'name'}

        optimizer_cls_D = torch.optim.Adam
        optimizer_params_D = {k:v for k, v in cfg['training']['optimizer_D'].items() 
                            if k != 'name'}

        self.BaseOpti = optimizer_cls(self.BaseNet.optim_parameters(cfg['training']['optimizer']['lr']), **optimizer_params)

        self.optimizers.extend([self.BaseOpti])

        self.DiscOptis = []
        for _d_net in self.DNets: 
            self.DiscOptis.append(
                optimizer_cls_D(_d_net.parameters(), **optimizer_params_D)
            )
        self.optimizers.extend(self.DiscOptis)

        self.schedulers = []        

        # The base network's learning rate is adjusted manually via lr_poly(); when
        # use_pseudo_label is False, scheduler_step() skips the scheduler below and
        # calls adjust_basenet_learning_rate() instead.
        self.learning_rate = cfg['training']['optimizer']['lr']
        self.gamma = cfg['training']['lr_schedule']['gamma']
        self.num_steps = cfg['training']['lr_schedule']['max_iter']
        self._BaseSchedule_nouse = get_scheduler(self.BaseOpti, cfg['training']['lr_schedule'])
        self.schedulers.extend([self._BaseSchedule_nouse])

        self.DiscSchedules = []
        for _disc_opt in self.DiscOptis:
            self.DiscSchedules.append(
                get_scheduler(_disc_opt, cfg['training']['lr_schedule'])
            )
        self.schedulers.extend(self.DiscSchedules)
        self.setup(cfg, writer, logger)

        self.adv_source_label = 0
        self.adv_target_label = 1
        self.bceloss = nn.BCEWithLogitsLoss(reduce=False)
        self.loss_fn = get_loss_function(cfg)
        pseudo_cfg = copy.deepcopy(cfg)
        pseudo_cfg['training']['loss']['name'] = 'cross_entropy4d'
        self.pseudo_loss_fn = get_loss_function(pseudo_cfg)
        self.mseloss = nn.MSELoss()
        self.l1loss = nn.L1Loss()
        self.smoothloss = nn.SmoothL1Loss()
        self.triplet_loss = nn.TripletMarginLoss()
        self.kl_distance = nn.KLDivLoss(reduction='none')

    def create_PredNet(self,):
        ss = DeepLab(
                num_classes=19,
                backbone=self.cfg_model['basenet']['version'],
                output_stride=16,
                bn=self.cfg_model['bn'],
                freeze_bn=True,
                modal_num=self.modal_num,
                ).cuda()
        ss.eval()
        return ss

    def setup(self, cfg, writer, logger):
        '''
        set optimizer and load pretrained model
        '''
        for net in self.nets:
            # name = net.__class__.__name__
            self.init_weights(cfg['model']['init'], logger, net)
            print("Initializition completed")
            if hasattr(net, '_load_pretrained_model') and cfg['model']['pretrained']:
                print("loading pretrained model for {}".format(net.__class__.__name__))
                net._load_pretrained_model()
        # resume training from a checkpoint if requested
        if cfg['training']['resume_flag']:
            self.load_nets(cfg, writer, logger)
        pass

    def lr_poly(self):
        return self.learning_rate * ((1 - float(self.iter) / self.num_steps) ** (self.gamma))

    def adjust_basenet_learning_rate(self):
        lr = self.lr_poly()
        self.BaseOpti.param_groups[0]['lr'] = lr
        if len(self.BaseOpti.param_groups) > 1:
            self.BaseOpti.param_groups[1]['lr'] = lr * 10
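    # For intuition: lr_poly() implements polynomial decay, lr * (1 - iter/max_iter) ** gamma.
    # With hypothetical values lr=2.5e-4, gamma=0.9, max_iter=100000:
    #   iter=0       -> 2.50e-4
    #   iter=50000   -> ~1.34e-4   (2.5e-4 * 0.5 ** 0.9)
    #   iter=100000  -> 0.0
    # The second param group (if present) tracks it at 10x, as set above.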

    def forward(self, input):
        feat, feat_low, att_mask, feat_cls, output = self.BaseNet_DP(input)
        return feat, feat_low, feat_cls, output

    def forward_Up(self, input):
        feat, feat_low, feat_cls, outputs = self.forward(input)
        merge_out = self.multimodal_merger(
            {
                'feat_cls': feat_cls,
                'output': outputs,
            },
            is_upsample=True,
            size=input.size()[2:],
        )
        return feat, feat_low, merge_out['feat_cls'], merge_out['output_comb']

    def PredNet_Forward(self, input):
        with torch.no_grad():
            feat, feat_low, att_mask, feat_cls, output_result = self.PredNet_DP(input)
        return feat, feat_low, att_mask, feat_cls, output_result

    def calculate_mean_vector(self, feat_cls, outputs, labels, ):
        outputs_softmax = F.softmax(outputs, dim=1)
        outputs_argmax = outputs_softmax.argmax(dim=1, keepdim=True)
        outputs_argmax = self.process_label(outputs_argmax.float())
        labels_expanded = self.process_label(labels)
        outputs_pred = labels_expanded * outputs_argmax
        scale_factor = F.adaptive_avg_pool2d(outputs_pred, 1)
        vectors = []
        ids = []
        for n in range(feat_cls.size()[0]):
            for t in range(self.class_numbers):
                if scale_factor[n][t].item()==0:
                    continue
                if (outputs_pred[n][t] > 0).sum() < 10:
                    continue
                s = feat_cls[n] * outputs_pred[n][t]
                scale = torch.sum(outputs_pred[n][t]) / labels.shape[2] / labels.shape[3] * 2
                s = normalisation_pooling()(s, scale)
                s = F.adaptive_avg_pool2d(s, 1) / scale_factor[n][t]
                vectors.append(s)
                ids.append(t)
        return vectors, ids
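    # Shape sketch (assumed, matching the 19 x 256 prototype buffers above): feat_cls is
    # [B, 256, h, w], outputs is [B, 19, h, w], labels is [B, 1, H, W]; each returned
    # vector is a [256, 1, 1] masked-average feature for one class, and ids holds the
    # corresponding class indices.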

    def step(self, source_x, source_label, source_modal_ids, target_x, target_label, target_modal_ids, use_pseudo_loss=False):
        assert len(source_modal_ids) == source_x.size(0), "modal_ids' batchsize != source_x's batchsize"
        _, _, source_feat_cls, source_output = self.forward(input=source_x) 
        """source_output: [B x 19 x W x H, ...]
        select modal-branch output in each batchsize
        Specific-modal output
        """
        source_output_modal_k = torch.stack(
            [
                source_output[_modal_i][_batch_i]
                for _batch_i, _modal_i in enumerate(source_modal_ids)
            ], 
            dim=0,
        )
        # attention output & specific-modal output
        source_output_comb = torch.cat([source_output_modal_k, source_output[-1]], dim=0)

        source_label_comb = torch.cat([source_label, source_label.clone()], dim=0)

        source_outputUp = F.interpolate(source_output_comb, size=source_x.size()[-2:], mode='bilinear', align_corners=True)

        loss_GTA = self.loss_fn(input=source_outputUp, target=source_label_comb)

        self.PredNet.eval()
        with torch.no_grad():
            _, _, att_mask, feat_cls, output = self.PredNet_Forward(target_x)

            threshold_args_comb, cluster_args_comb = self.metrics.update(feat_cls, output, target_label, modal_ids=[_i for _i in range(self.modal_num+1)], att_mask=att_mask)

            """ Discriminator-guided easy/hard training """
            target_label_size = target_label.size()
            t_out = output[-1]
            _t_out = F.interpolate(t_out.detach(), size=(target_label_size[1]*4, target_label_size[2]*4), mode='bilinear', align_corners=True)
            _t_D_out = self.PredDnet_DP(F.softmax(_t_out, dim=1))
            _t_D_out_prob = torch.sigmoid(_t_D_out)

            disc_easy_weight = torch.where(_t_D_out_prob > self.disc_T, _t_D_out_prob, torch.FloatTensor([0.0]).cuda())
            disc_easy_weight = torch.where(threshold_args_comb != 250, disc_easy_weight, torch.FloatTensor([0.0]).cuda()).squeeze(1)

            disc_hard_mask = torch.where(_t_D_out_prob < self.disc_T, torch.Tensor([1]).cuda(), torch.Tensor([0]).cuda())
            disc_hard_mask = torch.where(threshold_args_comb == 250, torch.Tensor([1]).cuda(), disc_hard_mask)


        loss_L2_source_cls = torch.Tensor([0]).cuda(self.split_gpu)
        loss_L2_target_cls = torch.Tensor([0]).cuda(self.split_gpu)
        _, _, target_feat_cls, target_output = self.forward(target_x)

        if self.cfg['training']['loss_L2_cls']:     # distance loss
            _batch, _w, _h = source_label.shape
            source_label_downsampled = source_label.reshape([_batch,1,_w, _h]).float()
            source_label_downsampled = F.interpolate(source_label_downsampled.float(), size=source_feat_cls[0].size()[-2:], mode='nearest')   #or F.softmax(input=source_output, dim=1)

            loss_L2_source_cls = torch.Tensor([0]).cuda()
            loss_L2_target_cls = torch.Tensor([0]).cuda()
            for _modal_i, _source_feat_i, _source_out_i, _target_feat_i, _target_out_i in zip(range(self.modal_num + 1), source_feat_cls, source_output, target_feat_cls, target_output):
                if _modal_i < 2:
                    continue
                source_vectors, source_ids = self.calculate_mean_vector(_source_feat_i, _source_out_i, source_label_downsampled)
                loss_L2_source_cls += self.class_vectors_alignment(source_ids, source_vectors, modal_ids=[_modal_i,])

                target_vectors, target_ids = self.calculate_mean_vector(_target_feat_i, _target_out_i, cluster_args_comb.float())
                loss_L2_target_cls += self.class_vectors_alignment(target_ids, target_vectors, modal_ids=[_modal_i,])

        loss_L2_cls = self.cls_feature_weight * (loss_L2_source_cls + loss_L2_target_cls)
        if loss_L2_cls.item() > 1.0:
            loss_L2_cls = loss_L2_cls / 10.0
        
        if loss_L2_cls.item() > 0.5:
            loss_L2_cls = loss_L2_cls / 3.0

        target_label_size = target_label.size()

        loss = torch.Tensor([0]).cuda()
        batch, _, w, h = threshold_args_comb.shape
        _cluster_args_comb = cluster_args_comb.reshape([batch, w, h])
        _threshold_args_comb = threshold_args_comb.reshape([batch, w, h])
        _target_output = target_output[-1]

        _loss_CTS = self.pseudo_loss_fn(input=_target_output, target=_threshold_args_comb)  # CAG-based and probability-based PLA
        loss_CTS = torch.sum(_loss_CTS * disc_easy_weight) / (1 + (disc_easy_weight > 0).sum())

        if self.G_train and self.cfg['training']['loss_pseudo_label']:
            loss = loss + loss_CTS
        if self.G_train and self.cfg['training']['loss_source_seg']:
            loss = loss + loss_GTA
        if self.cfg['training']['loss_L2_cls']:
            loss = loss + torch.sum(loss_L2_cls)

        # adversarial loss
        # -----------------------------
        """Generator (segmentation)"""
        # -----------------------------

        # On Source Domain 
        loss_adv = torch.Tensor([0]).cuda()
        _batch_size = 0

        source_modal_ids_tensor = torch.Tensor(source_modal_ids).cuda()
        target_modal_ids_tensor = torch.Tensor(target_modal_ids).cuda()
        for t_out, _d_net_DP, _d_net, modal_idx in zip(target_output, self.DNets_DP, self.DNets, range(len(target_output))):
            # set grad false
            self.set_requires_grad(self.logger, _d_net, requires_grad = False)
            t_D_out = _d_net_DP(F.softmax(t_out, dim=1))
            _disc_hard_mask = F.interpolate(disc_hard_mask, size=(t_D_out.size(2), t_D_out.size(3)), mode='nearest')
            #source_modal_ids
            loss_temp = torch.sum(self.bceloss(
                t_D_out,
                torch.FloatTensor(t_D_out.data.size()).fill_(1.0).cuda()
            ) * _disc_hard_mask, [1,2,3]) / (torch.sum(disc_hard_mask, [1,2,3]) + 1)

            if modal_idx >= self.modal_num:
                loss_adv += torch.mean(loss_temp)
            elif torch.mean(torch.as_tensor((modal_idx==target_modal_ids_tensor), dtype=torch.float32)) == 0:
                loss_adv += 0.0
            else:
                loss_adv += torch.mean(torch.masked_select(loss_temp, target_modal_ids_tensor==modal_idx))

            _batch_size += t_out.size(0)

        loss_adv *= self.cfg['training']['loss_adv_lambda']

        loss_G = torch.Tensor([0]).cuda()
        loss_G = loss_G + loss_adv
        loss = loss + loss_G
        if loss.item() != 0:
            loss.backward()

        self.BaseOpti.step()
        self.BaseOpti.zero_grad()

        # -----------------------------
        """Discriminator """
        # -----------------------------            
        _batch_size = 0
        loss_D_comb = torch.Tensor([0]).cuda()
        source_label_size = source_label.size()
        for s_out, t_out, _d_net_DP, _d_net, _disc_opt, modal_idx in zip(source_output, target_output, self.DNets_DP, self.DNets, self.DiscOptis, range(len(source_output))):
            self.set_requires_grad(self.logger, _d_net, requires_grad = True)
            
            _batch_size = 0
            loss_D = torch.Tensor([0]).cuda()
            # source domain
            s_D_out = _d_net_DP(F.softmax(s_out.detach(), dim=1))

            loss_temp_s = torch.mean(self.bceloss(
                s_D_out,
                torch.FloatTensor(s_D_out.data.size()).fill_(1.0).cuda()
            ), [1,2,3])

            if modal_idx >= self.modal_num:
                loss_D += torch.mean(loss_temp_s)
            elif torch.mean(torch.as_tensor((modal_idx==source_modal_ids_tensor), dtype=torch.float32)) == 0:
                loss_D += 0.0
            else:
                loss_D += torch.mean(torch.masked_select(loss_temp_s, source_modal_ids_tensor==modal_idx))

            # target domain
            _batch_size += (s_out.size(0) + t_out.size(0))
            t_D_out = _d_net_DP(F.softmax(t_out.detach(), dim=1))
            loss_temp_t = torch.mean(self.bceloss(
                t_D_out,
                torch.FloatTensor(t_D_out.data.size()).fill_(0.0).cuda()
            ), [1,2,3])

            if modal_idx >= self.modal_num:
                loss_D += torch.mean(loss_temp_t)
            elif torch.mean(torch.as_tensor((modal_idx==target_modal_ids_tensor), dtype=torch.float32)) == 0:
                loss_D += 0.0
            else:
                loss_D += torch.mean(torch.masked_select(loss_temp_t, target_modal_ids_tensor==modal_idx))

            loss_D *= self.cfg['training']['loss_adv_lambda']*0.5
            if loss_D.item() != 0:
                loss_D.backward()

            _disc_opt.step()
            _disc_opt.zero_grad()

            loss_D_comb += loss_D

        return loss, loss_adv, loss_D_comb


    def process_label(self, label):
        batch, channel, w, h = label.size()
        pred1 = torch.zeros(batch, 20, w, h).cuda()
        id = torch.where(label < 19, label, torch.Tensor([19]).cuda())
        pred1 = pred1.scatter_(1, id.long(), 1)
        return pred1
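    # Sketch (assumed input): a [B, 1, w, h] label map with values in {0..18} plus an
    # ignore index (e.g. 250) becomes a [B, 20, w, h] one-hot tensor; every value >= 19
    # is routed to the extra channel 19.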

    def class_vectors_alignment(self, ids, vectors, modal_ids=[0,]):
        loss = torch.Tensor([0]).cuda()

        """construct category objective vectors"""
        # objective_vectors_group 2 x 19 x 256 --> 19 x 512
        _objective_vectors_set = self.metrics.multimodal_merger.merge_objective_vectors(modal_ids=modal_ids)

        for i in range(len(ids)):
            if ids[i] not in self.valid_classes:
                continue
            new_loss = self.smoothloss(vectors[i].squeeze().cuda(), _objective_vectors_set[ids[i]])
            while (new_loss.item() > 5):
                new_loss = new_loss / 10
            loss = loss + new_loss
        if len(ids) > 0:
            loss = loss / len(ids) * 10
        return loss

    def freeze_bn_apply(self):
        for net in self.nets:
            net.apply(freeze_bn)
        for net in self.nets_DP:
            net.apply(freeze_bn)

    def scheduler_step(self):
        if self.use_pseudo_label:
            for scheduler in self.schedulers:
                scheduler.step()
        else:
            """skipped _BaseScheduler_nouse"""
            for scheduler in self.schedulers[1:]:
                scheduler.step()
            self.adjust_basenet_learning_rate()
    
    def optimizer_zerograd(self):
        for optimizer in self.optimizers:
            optimizer.zero_grad()
    
    def optimizer_step(self):
        for opt in self.optimizers:
            opt.step()

    def init_device(self, net, gpu_id=None, whether_DP=False):
        gpu_id = gpu_id or self.default_gpu
        device = torch.device("cuda:{}".format(gpu_id) if torch.cuda.is_available() else 'cpu')
        net = net.to(device)

        if whether_DP:
            net = DataParallelWithCallback(net, device_ids=range(torch.cuda.device_count()))
        return net
    
    def eval(self, net=None, logger=None):
        """Make specific models eval mode during test time"""
        if net is None:
            for net in self.nets:
                net.eval()
            for net in self.nets_DP:
                net.eval()
            if logger is not None:
                logger.info("Successfully set the model eval mode")
        else:
            net.eval()
            if logger is not None:
                logger.info("Successfully set {} eval mode".format(net.__class__.__name__))
        return

    def train(self, net=None, logger=None):
        if net is None:
            for net in self.nets:
                net.train()
            for net in self.nets_DP:
                net.train()
        else:
            net.train()
        return

    def set_requires_grad(self, logger, net, requires_grad = False):
        """Set requires_grad=Fasle for all the networks to avoid unnecessary computations
        Parameters:
            net (BaseModel)       -- the network which will be operated on
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        for parameter in net.parameters():
            parameter.requires_grad = requires_grad
        
    def set_requires_grad_layer(self, logger, net, layer_type='batchnorm', requires_grad=False):  
        '''Set whether a specific type of layer (e.g. batchnorm) requires grad, across all networks in self.nets.
        '''

        # print('Warning: all the BatchNorm params are fixed!')
        # logger.info('Warning: all the BatchNorm params are fixed!')
        for net in self.nets:
            for _i in net.modules():
                if _i.__class__.__name__.lower().find(layer_type.lower()) != -1:
                    _i.weight.requires_grad = requires_grad
        return

    def init_weights(self, cfg, logger, net, init_type='normal', init_gain=0.02):
        """Initialize network weights.

        Parameters:
            net (network)   -- network to be initialized
            init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
            init_gain (float)    -- scaling factor for normal, xavier and orthogonal.

        We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
        work better for some applications. Feel free to try yourself.
        """
        init_type = cfg.get('init_type', init_type)
        init_gain = cfg.get('init_gain', init_gain)
        def init_func(m):  # define the initialization function
            classname = m.__class__.__name__
            if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, init_gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=init_gain)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=init_gain)
                else:
                    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif isinstance(m, SynchronizedBatchNorm2d) or classname.find('BatchNorm2d') != -1 \
                or isinstance(m, nn.GroupNorm):
                # or isinstance(m, InPlaceABN) or isinstance(m, InPlaceABNSync):
                m.weight.data.fill_(1)
                m.bias.data.zero_() # BatchNorm Layer's weight is not a matrix; only normal distribution applies.


        print('initialize {} with {}'.format(net.__class__.__name__, init_type))
        logger.info('initialize {} with {}'.format(net.__class__.__name__, init_type))
        net.apply(init_func)  # apply the initialization function <init_func>
        pass

    def adaptive_load_nets(self, net, model_weight):
        model_dict = net.state_dict()
        pretrained_dict = {k : v for k, v in model_weight.items() if k in model_dict}
        
        print("[INFO] Pretrained dict:", pretrained_dict.keys())
        model_dict.update(pretrained_dict)
        net.load_state_dict(model_dict)

    def load_nets(self, cfg, writer, logger):    # load pretrained weights on the net
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg['training']['resume'])
            )
            checkpoint = torch.load(cfg['training']['resume'])
            _k = -1
            net_state_no = {}
            for net in self.nets:
                name = net.__class__.__name__
                if name not in net_state_no:
                    net_state_no[name] = 0
                else:
                    net_state_no[name] += 1
                _k += 1
                if checkpoint.get(name) is None:
                    continue
                if name.find('FCDiscriminator') != -1 and cfg['training']['gan_resume'] == False:
                    continue
                #self.adaptive_load_nets(net, checkpoint[name]["model_state"])
                if isinstance(checkpoint[name], list):
                    self.adaptive_load_nets(net, checkpoint[name][net_state_no[name]]["model_state"])
                else:
                    print("*****************************************")
                    print("[WARNING] Using depreciated load version! Model {}".format(name))
                    print("*****************************************")
                    self.adaptive_load_nets(net, checkpoint[name]["model_state"])
                if cfg['training']['optimizer_resume']:
                    if isinstance(checkpoint[name], list):
                        self.adaptive_load_nets(self.optimizers[_k], checkpoint[name][net_state_no[name]]["optimizer_state"])
                        self.adaptive_load_nets(self.schedulers[_k], checkpoint[name][net_state_no[name]]["scheduler_state"])
                    else:
                        self.adaptive_load_nets(self.optimizers[_k], checkpoint[name]["optimizer_state"])
                        self.adaptive_load_nets(self.schedulers[_k], checkpoint[name]["scheduler_state"])
            self.iter = checkpoint["iter"] if 'iter' in checkpoint else 0
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg['training']['resume'], self.iter
                )
            )
        else:
            raise Exception("No checkpoint found at '{}'".format(cfg['training']['resume']))


    def load_PredNet(self, cfg, writer, logger, dir=None, net=None):    # load pretrained weights on the net
        dir = dir or cfg['training']['Pred_resume']
        best_iou = 0
        if os.path.isfile(dir):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(dir)
            )
            checkpoint = torch.load(dir)
            name = net.__class__.__name__
            if checkpoint.get(name) is None:
                return
            if name.find('FCDiscriminator') != -1 and cfg['training']['gan_resume'] == False:
                return
            if isinstance(checkpoint[name], list):
                self.adaptive_load_nets(net, checkpoint[name][0]["model_state"])
            else:
                self.adaptive_load_nets(net, checkpoint[name]["model_state"])
            if 'iter' in checkpoint:
                checkpoint_iter = checkpoint["iter"]
            else:
                checkpoint_iter = 0
            if 'best_iou' in checkpoint:
                best_iou = checkpoint['best_iou']
            else:
                best_iou = 0
            logger.info(
                "Loaded checkpoint '{}' (iter {}) (best iou {}) for PredNet".format(
                    dir, checkpoint_iter, best_iou
                )
            )
        else:
            raise Exception("No checkpoint found at '{}'".format(dir))
        return best_iou

    def load_PredDnet(self, cfg, writer, logger, dir=None, net=None):    # load pretrained weights on the net
        dir = dir or cfg['training']['Pred_resume']
        best_iou = 0
        if os.path.isfile(dir):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(dir)
            )
            checkpoint = torch.load(dir)
            name = net.__class__.__name__
            if checkpoint.get(name) is None:
                return
            if name.find('FCDiscriminator') != -1 and cfg['training']['gan_resume'] == False:
                return
            if isinstance(checkpoint[name], list):
                self.adaptive_load_nets(net, checkpoint[name][-1]["model_state"]) # attention-branch discriminator
            else:
                print("[WARNING] load discriminator maybe error!")
                self.adaptive_load_nets(net, checkpoint[name]["model_state"])
            print("[INFO] {}: {}".format(name, net))
            iter = checkpoint["iter"]
            logger.info(
                "Loaded checkpoint '{}' (iter {}) for PredNet".format(
                    dir, checkpoint["iter"]
                )
            )
        else:
            raise Exception("No checkpoint found at '{}'".format(dir))
        return best_iou


    def set_optimizer(self, optimizer):  #set optimizer to all nets
        pass

    def reset_objective_SingleVector(self,):
        self.objective_vectors_group = torch.zeros(self.modal_num + 1, 19, 256).cuda()
        self.objective_vectors_num_group = torch.zeros(self.modal_num + 1, 19).cuda()
        self.objective_vectors_dis_group = torch.zeros(self.modal_num + 1, 19, 19).cuda()

    def update_objective_SingleVector(self, vectors, vectors_num, name='moving_average'):
        if torch.sum(vectors) == 0:
            return
        """
        if self.objective_vectors_num_group[modal_idx][id] < 100:
            name = 'mean'
        """
        if name == 'moving_average':
            self.objective_vectors_group = self.objective_vectors_group * 0.9999 + 0.0001 * vectors
            self.objective_vectors_num_group += vectors_num
            self.objective_vectors_num_group = min(self.objective_vectors_num_group, 3000)
        elif name == 'mean':
            self.objective_vectors_group = self.objective_vectors_group * self.objective_vectors_num_group.view(-1, 19, 1).expand(self.modal_num+1, 19, 256) + vectors
            self.objective_vectors_num_group = self.objective_vectors_num_group + vectors_num
            _objective_vectors_num_group = self.objective_vectors_num_group.clone()
            _ids = torch.where(_objective_vectors_num_group == 0)
            _objective_vectors_num_group[_ids] = 1.0
            self.objective_vectors_group = self.objective_vectors_group / _objective_vectors_num_group.view(-1, 19, 1).expand(self.modal_num+1, 19, 256)
            self.objective_vectors_num_group = torch.min(self.objective_vectors_num_group, torch.Tensor([3000]).cuda())
        else:
            raise NotImplementedError('no such updating way of objective vectors {}'.format(name))


class multisource_metatrainer(object):
    def __init__(self,
                 args,
                 nnclass,
                 meta_update_lr,
                 meta_update_step,
                 beta,
                 pretrain_mode='meta'):
        self.device = 1
        self.generator_model = None
        self.generator_optim = None
        self.generator_criterion = None
        self.pretrain_mode = pretrain_mode
        self.batch_size = args.batch_size
        self.nnclass = nnclass
        self.init_generator(args)
        self.init_discriminator(args)
        self.init_optimizer(args)
        self.meta_update_lr = meta_update_lr
        self.meta_update_step = meta_update_step
        self.beta = beta

    def init_generator(self, args):

        self.generator_model = DeepLab(num_classes=self.nnclass,
                                       backbone='resnet',
                                       output_stride=16,
                                       sync_bn=None,
                                       freeze_bn=False).cuda()

        self.generator_model = torch.nn.DataParallel(
            self.generator_model).cuda()
        patch_replication_callback(self.generator_model)
        if args.resume:
            print('#--------- load pretrained model --------------#')
            model_dict = self.generator_model.module.state_dict()
            checkpoint = torch.load(args.resume)
            pretrained_dict = {
                k: v
                for k, v in checkpoint['state_dict'].items()
                if 'last_conv' not in k and k in model_dict.keys()
            }
            #pretrained_dict = {k.replace('module.',''):v for k,v in checkpoint['state_dict'].items()  if 'last_conv' not in k}
            model_dict.update(pretrained_dict)
            self.generator_model.module.load_state_dict(model_dict)
        for param in self.generator_model.parameters():
            param.requires_grad = True

    def init_discriminator(self, args):
        # init D
        self.discriminator_model = FCDiscriminator(num_classes=2).cuda()
        self.interp = nn.Upsample(size=400, mode='bilinear')
        self.disc_criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        return

    def init_optimizer(self, args):
        self.generator_criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(
                mode='bce')  #torch.nn.BCELoss(reduce ='mean')
        self.generator_params = [{
            'params':
            self.generator_model.module.get_1x_lr_params(),
            'lr':
            args.lr
        }, {
            'params':
            self.generator_model.module.get_10x_lr_params(),
            'lr':
            args.lr * 10
        }]
        self.discriminator_params = [{
            'params':
            self.discriminator_model.parameters(),
            'lr':
            args.lr * 5
        }]
        self.model_optim = torch.optim.Adadelta(self.generator_params +
                                                self.discriminator_params)
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      lr_step=30,
                                      iters_per_epoch=100)

    # the two source domains are passed separately (srca / src_b) and concatenated along the batch dimension
    def update_weights(self, srca, srca_labels, src_b, srcb_labels, target_img,
                       target_label):
        #self.pretrain_mode = 'meta'
        src_labels = torch.cat([srca_labels.squeeze(),
                                srcb_labels.squeeze()],
                               0).type(torch.LongTensor).cuda()
        src_image = torch.cat([srca.squeeze(), src_b.squeeze()])
        if self.pretrain_mode == 'meta':
            seg_loss = self.meta_mldg(src_image, src_labels, self.batch_size)
        else:
            print('a default training is enabled')
            src_out, source_feature = self.generator_model(src_image)
            seg_loss = self.generator_criterion(src_out, src_labels)
        self.model_optim.zero_grad()
        seg_loss.backward()
        self.model_optim.step()
        target_logit, _ = self.generator_model(target_img.cuda())
        tgt_loss = self.generator_criterion(target_logit, target_label)
        tgt_loss = tgt_loss.detach()
        seg_loss = seg_loss.detach()
        return seg_loss, tgt_loss

    def meta_mldg(self, src_image, src_labels, batch_size):
        batch_size = 4  # hard-coded here, overriding the argument
        num_src = 2
        S = np.random.choice(num_src)  # meta-train source index
        V = abs(S - 1)                 # held-out (meta-test) source index
        source_out, _ = self.generator_model(src_image[S * batch_size:(S + 1) *
                                                       batch_size].squeeze())
        losses = self.generator_criterion(
            source_out, src_labels[S * batch_size:(S + 1) * batch_size])
        for k in range(1, self.meta_update_step):
            source_out, _ = self.generator_model(
                src_image[S * batch_size:(S + 1) * batch_size].squeeze())
            loss = self.generator_criterion(
                source_out, src_labels[S * batch_size:(S + 1) * batch_size])
            grad = torch.autograd.grad(loss, self.generator_model.parameters())
            fast_weights = list(
                map(lambda p: p[1] - self.meta_update_lr * p[0],
                    zip(grad, self.generator_model.parameters())))
            # compute the test loss on the fast weights
            Grad_test = self.generator_model(src_image[V * batch_size:(V + 1) *
                                                       batch_size],
                                             fast_weights,
                                             bn_training=True)
            # compute the gradient on generator_model
            losses += self.beta * Grad_test
        return losses
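
# The method above follows an MLDG-style meta-update: an inner gradient step on one
# source domain, then evaluation of the adapted ("fast") weights on the held-out source.
# Below is a self-contained sketch of that pattern on a toy model (hypothetical shapes
# and values; it uses torch.func.functional_call, available in PyTorch >= 2.0, instead of
# the project's DeepLab call with fast_weights/bn_training, so it illustrates the idea
# rather than this exact API).
def _mldg_toy_sketch(meta_lr=0.01, beta=0.5):
    import torch
    import torch.nn as nn
    from torch.func import functional_call

    model = nn.Linear(8, 3)                # stand-in for the generator
    criterion = nn.CrossEntropyLoss()

    x_tr, y_tr = torch.randn(4, 8), torch.randint(0, 3, (4,))  # meta-train source S
    x_te, y_te = torch.randn(4, 8), torch.randint(0, 3, (4,))  # held-out source V

    # inner step: gradient of the meta-train loss, kept in the graph (create_graph=True)
    inner_loss = criterion(model(x_tr), y_tr)
    grads = torch.autograd.grad(inner_loss, tuple(model.parameters()), create_graph=True)
    fast_weights = {name: p - meta_lr * g
                    for (name, p), g in zip(model.named_parameters(), grads)}

    # outer step: evaluate the adapted weights on the held-out source and combine
    outer_loss = criterion(functional_call(model, fast_weights, (x_te,)), y_te)
    (inner_loss + beta * outer_loss).backward()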
Example #6
class madan_trainer(object):
    def __init__(self, args, nnclass, ndomains):
        self.device = 1
        self.generator_model = None
        self.generator_optim = None
        self.generator_criterion = None
        self.batch_size = args.batch_size
        self.nnclass = nnclass
        self.num_domains = ndomains
        self.init_wasserstein = Wasserstein()
        self.init_generator(args)
        self.init_discriminator(args)
        self.init_optimizer(args)

    def init_generator(self, args):

        self.generator_model = DeepLab(num_classes=self.nnclass,
                                       backbone='resnet',
                                       output_stride=16,
                                       sync_bn=None,
                                       freeze_bn=False).cuda()

        self.generator_model = torch.nn.DataParallel(
            self.generator_model).cuda()
        patch_replication_callback(self.generator_model)
        if args.resume:
            print('#--------- load pretrained model --------------#')
            model_dict = self.generator_model.module.state_dict()
            checkpoint = torch.load(args.resume)
            pretrained_dict = {
                k: v
                for k, v in checkpoint['state_dict'].items()
                if 'last_conv' not in k and k in model_dict.keys()
            }
            #pretrained_dict = {k.replace('module.',''):v for k,v in checkpoint['state_dict'].items()  if 'last_conv' not in k}
            model_dict.update(pretrained_dict)
            self.generator_model.module.load_state_dict(model_dict)
        for param in self.generator_model.parameters():
            param.requires_grad = True

    def init_discriminator(self, args):
        # init D
        self.discriminator_model = FCDiscriminator(num_classes=2).cuda()
        self.interp = nn.Upsample(size=400, mode='bilinear')
        self.disc_criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        return

    def init_optimizer(self, args):
        self.generator_criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(
                mode='bce')  #torch.nn.BCELoss(reduce ='mean')
        self.generator_params = [{
            'params':
            self.generator_model.module.get_1x_lr_params(),
            'lr':
            args.lr
        }, {
            'params':
            self.generator_model.module.get_10x_lr_params(),
            'lr':
            args.lr * 10
        }]
        self.discriminator_params = [{
            'params':
            self.discriminator_model.parameters(),
            'lr':
            args.lr * 5
        }]
        self.model_optim = torch.optim.Adadelta(self.generator_params +
                                                self.discriminator_params)
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      lr_step=30,
                                      iters_per_epoch=100)

    # for madan the src_image has shape B x source_index x channel x H x W
    def update_weights(self, src_image, src_labels, targ_image, targ_labels,
                       options):
        running_loss = 0.0
        src_labels = torch.cat(
            [src_labels[:, 0].squeeze(), src_labels[:, 1].squeeze()],
            0).type(torch.LongTensor).cuda()
        self.model_optim.zero_grad()
        # src image shape batch_size x domain x 3 channels x height x width
        src_out, source_feature = self.generator_model(
            torch.cat([src_image[:, 0].squeeze(), src_image[:, 1].squeeze()]))
        targ_out, target_feature = self.generator_model(targ_image)
        #  Discriminator
        discriminator_x = torch.cat([source_feature, target_feature]).squeeze()
        disc_clf = self.discriminator_model(discriminator_x)
        # Losses
        losses = torch.stack([
            self.generator_criterion(
                src_out[j * self.batch_size:(j + 1) * self.batch_size],
                src_labels[j * self.batch_size:(j + 1) * self.batch_size])
            for j in range(self.num_domains)
        ])
        slabels = torch.ones(self.batch_size,
                             disc_clf.shape[2],
                             disc_clf.shape[3],
                             requires_grad=False).type(
                                 torch.LongTensor).cuda()
        tlabels = torch.zeros(self.batch_size * 2,
                              disc_clf.shape[2],
                              disc_clf.shape[3],
                              requires_grad=False).type(
                                  torch.LongTensor).cuda()
        domain_losses = torch.stack([
            self.generator_criterion(
                disc_clf[j * self.batch_size:(j + 1) * self.batch_size].squeeze(),
                slabels) for j in range(self.num_domains)
        ])
        domain_losses = torch.cat([
            domain_losses,
            self.generator_criterion(
                disc_clf[2 * self.batch_size:2 * self.batch_size +
                         2 * self.batch_size].squeeze(), tlabels).view(-1)
        ])
        # Different final loss function depending on different training modes.
        if options['mode'] == "maxmin":
            loss = torch.max(losses) + options['mu'] * torch.min(domain_losses)
        elif options['mode'] == "dynamic":
            loss = torch.log(
                torch.sum(
                    torch.exp(options['gamma'] *
                              (losses + options['mu'] * domain_losses)))
            ) / options['gamma']
        else:
            raise ValueError(
                "No support for the training mode on madnNet: {}.".format(
                    options['mode']))
        loss.backward()
        self.model_optim.step()
        running_loss += loss.detach().cpu().numpy()
        # compute target loss
        target_loss = self.generator_criterion(
            targ_out, targ_labels).detach().cpu().numpy()
        return running_loss, target_loss

    def update_wasserstein(self, src_image, src_labels, targ_image,
                           targ_labels, options):
        running_loss = 0.0
        src_labels = torch.cat(
            [src_labels[:, 0].squeeze(), src_labels[:, 1].squeeze()],
            0).type(torch.LongTensor).cuda()
        self.model_optim.zero_grad()
        # src image shape batch_size x domain x 3 channels x height x width
        src_out, source_feature = self.generator_model(
            torch.cat([src_image[:, 0].squeeze(), src_image[:, 1].squeeze()]))
        targ_out, target_feature = self.generator_model(targ_image)
        #  Discriminator
        discriminator_x = torch.cat([source_feature, target_feature]).squeeze()
        disc_clf = self.discriminator_model(discriminator_x)
        # Losses
        losses = torch.stack([
            self.generator_criterion(
                src_out[j * self.batch_size:j * self.batch_size +
                        self.batch_size],
                src_labels[j * self.batch_size:j * self.batch_size +
                           self.batch_size]) for j in range(self.num_domains)
        ])
        wass_loss = [
            self.init_wasserstein.update_wasserstein_dual_source(
                disc_clf[j * self.batch_size:j * self.batch_size +
                         self.batch_size].squeeze(),
                disc_clf[self.num_domains *
                         self.batch_size:self.num_domains * self.batch_size +
                         self.batch_size].squeeze())
            for j in range(self.num_domains)
        ]

        domain_losses = torch.stack(wass_loss)
        # compute gradient penalty
        penalty_cup, penalty_disc = self.init_wasserstein.gradient_regularization_dual_source(
            self.discriminator_model, source_feature.detach(),
            target_feature.detach(), options['batch_size'],
            options['num_domains'])
        # Different final loss function depending on different training modes.
        if options['mode'] == "maxmin":
            loss = torch.max(
                losses) + options['mu'] * torch.min(domain_losses) + options[
                    'gamma'] * penalty_cup + options['gamma'] * penalty_disc
        elif options['mode'] == "dynamic":
            # TODO Wasserstein not implemented yet for this
            loss = torch.log(
                torch.sum(
                    torch.exp(options['gamma'] *
                              (losses + options['mu'] * domain_losses)))
            ) / options['gamma']
        else:
            raise ValueError(
                "No support for the training mode on madnNet: {}.".format(
                    options['mode']))
        loss.backward()
        self.model_optim.step()
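        # WGAN-style weight clipping: constraining the discriminator's weights to a small
        # box keeps the critic approximately Lipschitz, as required by the Wasserstein loss.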
        for p in self.discriminator_model.parameters():
            p.data.clamp_(-0.01, 0.01)
        running_loss += loss.detach().cpu().numpy()
        # compute target loss
        target_loss = self.generator_criterion(
            targ_out, targ_labels).detach().cpu().numpy()
        return running_loss, target_loss
Example #7
def main():
    """Create the model and start the training."""
    args = get_arguments()

    cudnn.enabled = True
    n_discriminators = 5

    # create teacher & student
    student_net = UNet(3, n_classes=args.num_classes)
    teacher_net = UNet(3, n_classes=args.num_classes)
    student_params = list(student_net.parameters())

    # teacher doesn't need gradients as it's just an EMA of the student
    teacher_params = list(teacher_net.parameters())
    for param in teacher_params:
        param.requires_grad = False

    student_net.train()
    student_net.cuda(args.gpu)
    teacher_net.train()
    teacher_net.cuda(args.gpu)

    cudnn.benchmark = True
    unsup_weights = [
        args.unsup_weight5, args.unsup_weight6, args.unsup_weight7,
        args.unsup_weight8, args.unsup_weight9
    ]
    lambda_adv_tgts = [
        args.lambda_adv_tgt5, args.lambda_adv_tgt6, args.lambda_adv_tgt7,
        args.lambda_adv_tgt8, args.lambda_adv_tgt9
    ]

    # create a list of discriminators
    discriminators = []
    for dis_idx in range(n_discriminators):
        discriminators.append(FCDiscriminator(num_classes=args.num_classes))
        discriminators[dis_idx].train()
        discriminators[dis_idx].cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    max_iters = args.num_steps * args.iter_size * args.batch_size
    src_set = REFUGE(True,
                     domain='REFUGE_SRC',
                     is_transform=True,
                     augmentations=aug_student,
                     aug_for_target=aug_teacher,
                     max_iters=max_iters)
    src_loader = data.DataLoader(src_set,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    src_loader_iter = enumerate(src_loader)
    tgt_set = REFUGE(True,
                     domain='REFUGE_DST',
                     is_transform=True,
                     augmentations=aug_student,
                     aug_for_target=aug_teacher,
                     max_iters=max_iters)
    tgt_loader = data.DataLoader(tgt_set,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    tgt_loader_iter = enumerate(tgt_loader)
    student_optimizer = optim.SGD(student_params,
                                  lr=args.learning_rate,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
    teacher_optimizer = optim_weight_ema.WeightEMA(teacher_params,
                                                   student_params,
                                                   alpha=args.teacher_alpha)
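    # WeightEMA (defined in optim_weight_ema) is assumed to update each teacher
    # parameter as an exponential moving average of the student parameters, e.g.
    #     teacher_p = alpha * teacher_p + (1 - alpha) * student_p
    # which is why the teacher needs no gradients and no conventional optimizer.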

    d_optimizers = []
    for idx in range(n_discriminators):
        optimizer = optim.Adam(discriminators[idx].parameters(),
                               lr=args.learning_rate_D,
                               betas=(0.9, 0.99))
        d_optimizers.append(optimizer)

    calc_bce_loss = torch.nn.BCEWithLogitsLoss()

    # labels for adversarial training
    source_label, tgt_label = 0, 1
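    # The discriminators learn to predict source_label (0) for source-domain
    # outputs and tgt_label (1) for target-domain outputs; the student is trained
    # adversarially below by labelling its target-domain outputs with source_label.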
    for i_iter in range(args.num_steps):

        total_seg_loss = 0
        seg_loss_vals = [0] * n_discriminators
        adv_tgt_loss_vals = [0] * n_discriminators
        d_loss_vals = [0] * n_discriminators
        unsup_loss_vals = [0] * n_discriminators

        for d_optimizer in d_optimizers:
            d_optimizer.zero_grad()
            adjust_learning_rate_D(d_optimizer, i_iter, args)

        student_optimizer.zero_grad()
        adjust_learning_rate(student_optimizer, i_iter, args)

        for sub_i in range(args.iter_size):

            # ******** Optimize source network with segmentation loss ********
            # As we don't change the discriminators, their parameters are fixed
            for discriminator in discriminators:
                for param in discriminator.parameters():
                    param.requires_grad = False

            _, src_batch = src_loader_iter.__next__()
            _, _, src_images, src_labels, _ = src_batch
            src_images = Variable(src_images).cuda(args.gpu)

            # calculate the segmentation losses
            sup_preds = list(student_net(src_images))
            seg_losses, total_seg_loss = [], 0
            for idx, sup_pred in enumerate(sup_preds):
                sup_interp_pred = sup_pred
                # a Dice loss could be used here instead, e.g. dice_loss(src_labels, sup_interp_pred)
                seg_loss = Weighted_Jaccard_loss(src_labels, sup_interp_pred,
                                                 args.class_weights, args.gpu)
                seg_losses.append(seg_loss)
                total_seg_loss += seg_loss * unsup_weights[idx]
                seg_loss_vals[idx] += seg_loss.item() / args.iter_size

            _, tgt_batch = tgt_loader_iter.__next__()
            tgt_images0, tgt_lbl0, tgt_images1, tgt_lbl1, _ = tgt_batch
            tgt_images0 = Variable(tgt_images0).cuda(args.gpu)
            tgt_images1 = Variable(tgt_images1).cuda(args.gpu)

            # calculate the consistency (mean-teacher ensemble) losses
            stu_unsup_preds = list(student_net(tgt_images1))
            tea_unsup_preds = teacher_net(tgt_images0)
            total_mse_loss = 0
            for idx in range(n_discriminators):
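                # class probabilities: softmax over the channel dimension of the
                # N x C x H x W predictions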
                stu_unsup_probs = F.softmax(stu_unsup_preds[idx], dim=1)
                tea_unsup_probs = F.softmax(tea_unsup_preds[idx], dim=1)

                unsup_loss = calc_mse_loss(stu_unsup_probs, tea_unsup_probs,
                                           args.batch_size)
                unsup_loss_vals[idx] += unsup_loss.item() / args.iter_size
                total_mse_loss += unsup_loss * unsup_weights[idx]

            total_mse_loss = total_mse_loss / args.iter_size

            # Since requires_grad is False for the discriminators, gradients only
            # accumulate in the generator (the student network), which is optimized
            # so that its outputs on target-domain images are indistinguishable
            # from its outputs on source-domain images.
            stu_unsup_preds = list(student_net(tgt_images0))
            d_outs, total_adv_loss = [], 0
            for idx in range(n_discriminators):
                stu_unsup_interp_pred = stu_unsup_preds[idx]
                d_outs.append(discriminators[idx](stu_unsup_interp_pred))
                label_size = d_outs[idx].data.size()
                labels = torch.FloatTensor(label_size).fill_(source_label)
                labels = Variable(labels).cuda(args.gpu)
                adv_tgt_loss = calc_bce_loss(d_outs[idx], labels)

                total_adv_loss += lambda_adv_tgts[idx] * adv_tgt_loss
                adv_tgt_loss_vals[idx] += adv_tgt_loss.item() / args.iter_size

            total_adv_loss = total_adv_loss / args.iter_size

            # With requires_grad set back to True for the discriminators, gradients
            # now accumulate only in the discriminators, which are optimized to
            # correctly separate source predictions from target predictions.
            d_losses = []
            for idx in range(n_discriminators):
                discriminator = discriminators[idx]
                for param in discriminator.parameters():
                    param.requires_grad = True

                sup_preds[idx] = sup_preds[idx].detach()
                d_outs[idx] = discriminators[idx](sup_preds[idx])

                label_size = d_outs[idx].data.size()
                labels = torch.FloatTensor(label_size).fill_(source_label)
                labels = Variable(labels).cuda(args.gpu)

                d_losses.append(calc_bce_loss(d_outs[idx], labels))
                d_losses[idx] = d_losses[idx] / args.iter_size / 2
                d_losses[idx].backward()
                d_loss_vals[idx] += d_losses[idx].item()

            for idx in range(n_discriminators):
                stu_unsup_preds[idx] = stu_unsup_preds[idx].detach()
                d_outs[idx] = discriminators[idx](stu_unsup_preds[idx])

                label_size = d_outs[idx].data.size()
                labels = torch.FloatTensor(label_size).fill_(tgt_label)
                labels = Variable(labels).cuda(args.gpu)

                d_losses[idx] = calc_bce_loss(d_outs[idx], labels)
                d_losses[idx] = d_losses[idx] / args.iter_size / 2
                d_losses[idx].backward()
                d_loss_vals[idx] += d_losses[idx].item()

        for d_optimizer in d_optimizers:
            d_optimizer.step()

        total_loss = total_seg_loss + total_adv_loss + total_mse_loss
        total_loss.backward()
        student_optimizer.step()
        teacher_optimizer.step()

        log_str = 'iter = {0:7d}/{1:7d}'.format(i_iter, args.num_steps)
        log_str += ', total_seg_loss = {0:.3f} '.format(total_seg_loss)
        templ = 'seg_losses = [' + ', '.join(['%.2f'] * len(seg_loss_vals))
        log_str += templ % tuple(seg_loss_vals) + '] '
        templ = 'ens_losses = [' + ', '.join(['%.5f'] * len(unsup_loss_vals))
        log_str += templ % tuple(unsup_loss_vals) + '] '
        templ = 'adv_losses = [' + ', '.join(['%.2f'] * len(adv_tgt_loss_vals))
        log_str += templ % tuple(adv_tgt_loss_vals) + '] '
        templ = 'd_losses = [' + ', '.join(['%.2f'] * len(d_loss_vals))
        log_str += templ % tuple(d_loss_vals) + '] '

        print(log_str)
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            filename = 'UNet' + str(
                args.num_steps_stop) + '_v18_weightedclass.pth'
            torch.save(teacher_net.cpu().state_dict(),
                       os.path.join(args.snapshot_dir, filename))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            filename = 'UNet' + str(i_iter) + '_v18_weightedclass.pth'
            torch.save(teacher_net.cpu().state_dict(),
                       os.path.join(args.snapshot_dir, filename))
            teacher_net.cuda(args.gpu)
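# Hedged sketch of the learning-rate helpers assumed by main() above. The real
# adjust_learning_rate / adjust_learning_rate_D are defined elsewhere in the repo;
# the poly decay below (power 0.9) is only a plausible reconstruction, named with
# a _sketch suffix so it does not shadow the originals.
def lr_poly_sketch(base_lr, i_iter, max_iter, power=0.9):
    return base_lr * ((1 - float(i_iter) / max_iter) ** power)

def adjust_learning_rate_sketch(optimizer, i_iter, args):
    # set every parameter group to the decayed learning rate
    lr = lr_poly_sketch(args.learning_rate, i_iter, args.num_steps)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr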
Example #8
class adda_trainer(object):
    def __init__(self, args, nnclass):
        self.target_model = None
        self.target_optim = None
        self.target_criterion = None
        self.batch_size = args.batch_size
        self.nnclass = nnclass
        self.init_target(args)
        self.init_discriminator(args)
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      lr_step=40,
                                      iters_per_epoch=100)
        self.disc_params = [{
            'params': self.disc_model.parameters(),
            'lr': args.lr * 5
        }]
        self.dda_optim = torch.optim.Adam(self.train_params)
        self.discriminator_optim = torch.optim.Adam(self.disc_params)
        #self.dda_optim = torch.optim.SGD(self.train_params, momentum=args.momentum,
        #                            weight_decay=args.weight_decay, nesterov=args.nesterov)
        #self.discriminator_optim = torch.optim.SGD(self.disc_params, momentum=args.momentum,
        #                            weight_decay=args.weight_decay, nesterov=args.nesterov)
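        # FGSM-based adversarial augmentation: epsilon 0.0157 (~4/255), step size
        # 0.00784 (~2/255), 2 iterations under an L-inf constraint.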
        self.adv_aug = FastGradientSignUntargeted(self.target_model,
                                                  0.0157,
                                                  0.00784,
                                                  min_val=0,
                                                  max_val=1,
                                                  max_iters=2,
                                                  _type='linf')

    def init_target(self, args):

        self.target_model = DeepLab(num_classes=self.nnclass,
                                    backbone='resnet',
                                    output_stride=16,
                                    sync_bn=None,
                                    freeze_bn=False)
        self.train_params = [{
            'params': self.target_model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': self.target_model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]
        self.target_model = torch.nn.DataParallel(self.target_model)
        self.target_criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(
                mode='bce')  #torch.nn.BCELoss(reduce ='mean')
        patch_replication_callback(self.target_model)
        model_dict = self.target_model.module.state_dict()
        checkpoint = torch.load(args.resume)
        pretrained_dict = {
            k.replace('module.', ''): v
            for k, v in checkpoint['state_dict'].items()
        }
        #pretrained_dict = {k:v for k,v in checkpoint['state_dict'].items() if 'last_conv' not in k }
        model_dict.update(pretrained_dict)
        self.target_model.module.load_state_dict(model_dict)
        self.target_model = self.target_model.cuda()
        return

    def init_discriminator(self, args):
        # init D
        self.disc_model = FCDiscriminator(num_classes=2).cuda()
        self.interp = nn.Upsample(size=400, mode='bilinear')
        self.disc_criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        return

    def update_weights(self, input_, src_labels, target, tgt_labels, lamda_g,
                       trainmodel):

        self.dda_optim.zero_grad()
        self.discriminator_optim.zero_grad()
        if trainmodel == 'train_gen':
            for param in self.target_model.parameters():
                param.requires_grad = True
            for param in self.disc_model.parameters():
                param.requires_grad = False
            self.disc_model.eval()
            self.target_model.train()
        else:
            for param in self.target_model.parameters():
                param.requires_grad = False
            for param in self.disc_model.parameters():
                param.requires_grad = True
            self.disc_model.train()
            self.target_model.eval()
        #tot_input = torch.cat([input_, target])
        #import pdb
        #pdb.set_trace()
        src_out, source_feature = self.target_model(input_)
        seg_loss = self.target_criterion(src_out, src_labels)
        #print(target.shape)
        targ_out, target_feature = self.target_model(target)

        # discriminator
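        # Build two opposite label maps over the concatenated source/target batch:
        # the "adv" labels drive the generator's adversarial loss, while the
        # flipped "real" labels drive the discriminator's loss below.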
        discriminator_x = torch.cat([source_feature, target_feature]).squeeze()
        discriminator_adv_logit = torch.cat([
            torch.zeros(source_feature.shape),
            torch.ones(target_feature.shape)
        ])
        discriminator_real_logit = torch.cat([
            torch.ones(source_feature.shape),
            torch.zeros(target_feature.shape)
        ])
        disc_out = self.disc_model(discriminator_x)
        #print(source_feature.shape, input_.shape,discriminator_adv_logit.shape, disc_out.shape)
        adv_loss = self.target_criterion(
            disc_out, discriminator_adv_logit[:, 0, :, :].cuda())
        adv_loss += self.target_criterion(
            disc_out, discriminator_adv_logit[:, 1, :, :].cuda())
        disc_loss = self.disc_criterion(
            disc_out, discriminator_real_logit[:, 0, :, :].cuda())
        disc_loss += self.disc_criterion(
            disc_out, discriminator_real_logit[:, 1, :, :].cuda())
        if trainmodel == 'train_gen':
            loss_seg = seg_loss + lamda_g * adv_loss
            loss_seg.backward()
            self.dda_optim.step()
        else:
            disc_loss.backward()
            self.discriminator_optim.step()
        tgt_loss = self.target_criterion(targ_out, tgt_labels)
        return seg_loss.data.cpu().numpy(), tgt_loss.data.cpu().numpy()
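    # Hedged usage sketch (not part of the snippet above): a training loop would
    # alternate the two modes of update_weights each iteration, e.g.
    #     seg_loss, tgt_loss = trainer.update_weights(src_img, src_lbl, tgt_img, tgt_lbl,
    #                                                 lamda_g, trainmodel='train_gen')
    #     seg_loss, tgt_loss = trainer.update_weights(src_img, src_lbl, tgt_img, tgt_lbl,
    #                                                 lamda_g, trainmodel='train_disc')
    # any value other than 'train_gen' triggers the discriminator branch.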
Example #9
    def __init__(self, opt, logger, isTrain=True):
        self.opt = opt
        self.class_numbers = opt.n_class
        self.logger = logger
        self.best_iou = -100
        self.nets = []
        self.nets_DP = []
        self.default_gpu = 0
        self.objective_vectors = torch.zeros([self.class_numbers, 256])
        self.objective_vectors_num = torch.zeros([self.class_numbers])

        if opt.bn == 'sync_bn':
            BatchNorm = SynchronizedBatchNorm2d
        elif opt.bn == 'bn':
            BatchNorm = nn.BatchNorm2d
        else:
            raise NotImplementedError('batch norm choice {} is not implemented'.format(opt.bn))

        if self.opt.no_resume:
            restore_from = None
        else:
            restore_from= opt.resume_path
            self.best_iou = 0
        if self.opt.student_init == 'imagenet':
            self.BaseNet = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from)
        elif self.opt.student_init == 'simclr':
            self.BaseNet = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from, 
                initialization=os.path.join(opt.root, 'Code/ProDA', 'pretrained/simclr/r101_1x_sk0.pth'), bn_clr=opt.bn_clr)
        else:
            self.BaseNet = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from)
            
        logger.info('the backbone is {}'.format(opt.model_name))

        self.nets.extend([self.BaseNet])

        self.optimizers = []
        self.schedulers = []        
        optimizer_cls = torch.optim.SGD
        optimizer_params = {'lr':opt.lr, 'weight_decay':2e-4, 'momentum':0.9}

        if self.opt.stage == 'warm_up':
            self.net_D = FCDiscriminator(inplanes=self.class_numbers)
            self.net_D_DP = self.init_device(self.net_D, gpu_id=self.default_gpu, whether_DP=True)
            self.nets.extend([self.net_D])
            self.nets_DP.append(self.net_D_DP)

            self.optimizer_D = torch.optim.Adam(self.net_D.parameters(), lr=1e-4, betas=(0.9, 0.99))
            self.optimizers.extend([self.optimizer_D])
            self.DSchedule = get_scheduler(self.optimizer_D, opt)
            self.schedulers.extend([self.DSchedule])

        if self.opt.finetune or self.opt.stage == 'warm_up':
            self.BaseOpti = optimizer_cls([{'params':self.BaseNet.get_1x_lr_params(), 'lr':optimizer_params['lr']},
                                           {'params':self.BaseNet.get_10x_lr_params(), 'lr':optimizer_params['lr']*10}], **optimizer_params)
        else:
            self.BaseOpti = optimizer_cls(self.BaseNet.parameters(), **optimizer_params)
        self.optimizers.extend([self.BaseOpti])

        self.BaseSchedule = get_scheduler(self.BaseOpti, opt)
        self.schedulers.extend([self.BaseSchedule])

        if self.opt.ema:
            self.BaseNet_ema = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from, bn_clr=opt.ema_bn)
            self.BaseNet_ema.load_state_dict(self.BaseNet.state_dict().copy())

        if self.opt.distillation > 0:
            self.teacher = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=opt.resume_path, bn_clr=opt.ema_bn)
            self.teacher.eval()
            self.teacher_DP = self.init_device(self.teacher, gpu_id=self.default_gpu, whether_DP=True)


        self.adv_source_label = 0
        self.adv_target_label = 1
        if self.opt.gan == 'Vanilla':
            self.bceloss = nn.BCEWithLogitsLoss(reduction='mean')
        elif self.opt.gan == 'LS':
            self.bceloss = torch.nn.MSELoss()
        self.feat_prototype_distance_DP = self.init_device(feat_prototype_distance_module(), gpu_id=self.default_gpu, whether_DP=True)

        self.BaseNet_DP = self.init_device(self.BaseNet, gpu_id=self.default_gpu, whether_DP=True)
        self.nets_DP.append(self.BaseNet_DP)
        if self.opt.ema:
            self.BaseNet_ema_DP = self.init_device(self.BaseNet_ema, gpu_id=self.default_gpu, whether_DP=True)
Example #10
class CustomModel():
    def __init__(self, opt, logger, isTrain=True):
        self.opt = opt
        self.class_numbers = opt.n_class
        self.logger = logger
        self.best_iou = -100
        self.nets = []
        self.nets_DP = []
        self.default_gpu = 0
        self.objective_vectors = torch.zeros([self.class_numbers, 256])
        self.objective_vectors_num = torch.zeros([self.class_numbers])

        if opt.bn == 'sync_bn':
            BatchNorm = SynchronizedBatchNorm2d
        elif opt.bn == 'bn':
            BatchNorm = nn.BatchNorm2d
        else:
            raise NotImplementedError('batch norm choice {} is not implemented'.format(opt.bn))

        if self.opt.no_resume:
            restore_from = None
        else:
            restore_from= opt.resume_path
            self.best_iou = 0
        if self.opt.student_init == 'imagenet':
            self.BaseNet = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from)
        elif self.opt.student_init == 'simclr':
            self.BaseNet = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from, 
                initialization=os.path.join(opt.root, 'Code/ProDA', 'pretrained/simclr/r101_1x_sk0.pth'), bn_clr=opt.bn_clr)
        else:
            self.BaseNet = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from)
            
        logger.info('the backbone is {}'.format(opt.model_name))

        self.nets.extend([self.BaseNet])

        self.optimizers = []
        self.schedulers = []        
        optimizer_cls = torch.optim.SGD
        optimizer_params = {'lr':opt.lr, 'weight_decay':2e-4, 'momentum':0.9}

        if self.opt.stage == 'warm_up':
            self.net_D = FCDiscriminator(inplanes=self.class_numbers)
            self.net_D_DP = self.init_device(self.net_D, gpu_id=self.default_gpu, whether_DP=True)
            self.nets.extend([self.net_D])
            self.nets_DP.append(self.net_D_DP)

            self.optimizer_D = torch.optim.Adam(self.net_D.parameters(), lr=1e-4, betas=(0.9, 0.99))
            self.optimizers.extend([self.optimizer_D])
            self.DSchedule = get_scheduler(self.optimizer_D, opt)
            self.schedulers.extend([self.DSchedule])

        if self.opt.finetune or self.opt.stage == 'warm_up':
            self.BaseOpti = optimizer_cls([{'params':self.BaseNet.get_1x_lr_params(), 'lr':optimizer_params['lr']},
                                           {'params':self.BaseNet.get_10x_lr_params(), 'lr':optimizer_params['lr']*10}], **optimizer_params)
        else:
            self.BaseOpti = optimizer_cls(self.BaseNet.parameters(), **optimizer_params)
        self.optimizers.extend([self.BaseOpti])

        self.BaseSchedule = get_scheduler(self.BaseOpti, opt)
        self.schedulers.extend([self.BaseSchedule])

        if self.opt.ema:
            self.BaseNet_ema = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from, bn_clr=opt.ema_bn)
            self.BaseNet_ema.load_state_dict(self.BaseNet.state_dict().copy())

        if self.opt.distillation > 0:
            self.teacher = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=opt.resume_path, bn_clr=opt.ema_bn)
            self.teacher.eval()
            self.teacher_DP = self.init_device(self.teacher, gpu_id=self.default_gpu, whether_DP=True)


        self.adv_source_label = 0
        self.adv_target_label = 1
        if self.opt.gan == 'Vanilla':
            self.bceloss = nn.BCEWithLogitsLoss(reduction='mean')
        elif self.opt.gan == 'LS':
            self.bceloss = torch.nn.MSELoss()
        self.feat_prototype_distance_DP = self.init_device(feat_prototype_distance_module(), gpu_id=self.default_gpu, whether_DP=True)

        self.BaseNet_DP = self.init_device(self.BaseNet, gpu_id=self.default_gpu, whether_DP=True)
        self.nets_DP.append(self.BaseNet_DP)
        if self.opt.ema:
            self.BaseNet_ema_DP = self.init_device(self.BaseNet_ema, gpu_id=self.default_gpu, whether_DP=True)

    def calculate_mean_vector(self, feat_cls, outputs, labels=None, thresh=None):
        outputs_softmax = F.softmax(outputs, dim=1)
        if thresh is None:
            thresh = -1
        conf = outputs_softmax.max(dim=1, keepdim=True)[0]
        mask = conf.ge(thresh)
        outputs_argmax = outputs_softmax.argmax(dim=1, keepdim=True)
        outputs_argmax = self.process_label(outputs_argmax.float())
        if labels is None:
            outputs_pred = outputs_argmax
        else:
            labels_expanded = self.process_label(labels)
            outputs_pred = labels_expanded * outputs_argmax
        scale_factor = F.adaptive_avg_pool2d(outputs_pred * mask, 1)
        vectors = []
        ids = []
        for n in range(feat_cls.size()[0]):
            for t in range(self.class_numbers):
                if scale_factor[n][t].item()==0:
                    continue
                if (outputs_pred[n][t] > 0).sum() < 10:
                    continue
                s = feat_cls[n] * outputs_pred[n][t] * mask[n]
                # scale = torch.sum(outputs_pred[n][t]) / labels.shape[2] / labels.shape[3] * 2
                # s = normalisation_pooling()(s, scale)
                s = F.adaptive_avg_pool2d(s, 1) / scale_factor[n][t]
                vectors.append(s)
                ids.append(t)
        return vectors, ids

    def step_adv(self, source_x, source_label, target_x, source_imageS, source_params):
        for param in self.net_D.parameters():
            param.requires_grad = False
        self.BaseOpti.zero_grad()
        
        if self.opt.S_pseudo_src > 0:
            source_output = self.BaseNet_DP(source_imageS)
            source_label_d4 = F.interpolate(source_label.unsqueeze(1).float(), size=source_output['out'].size()[2:])
            source_labelS = self.label_strong_T(source_label_d4.clone().float(), source_params, padding=250, scale=4).to(torch.int64)
            loss_ = cross_entropy2d(input=source_output['out'], target=source_labelS.squeeze(1))
            loss_GTA = loss_ * self.opt.S_pseudo_src
            source_outputUp = F.interpolate(source_output['out'], size=source_x.size()[2:], mode='bilinear', align_corners=True)
        else:
            source_output = self.BaseNet_DP(source_x, ssl=True)
            source_outputUp = F.interpolate(source_output['out'], size=source_x.size()[2:], mode='bilinear', align_corners=True)

            loss_GTA = cross_entropy2d(input=source_outputUp, target=source_label, size_average=True, reduction='mean')

        target_output = self.BaseNet_DP(target_x, ssl=True)
        target_outputUp = F.interpolate(target_output['out'], size=target_x.size()[2:], mode='bilinear', align_corners=True)
        target_D_out = self.net_D_DP(F.softmax(target_outputUp, dim=1))
        loss_adv_G = self.bceloss(target_D_out, torch.FloatTensor(target_D_out.data.size()).fill_(self.adv_source_label).to(target_D_out.device)) * self.opt.adv
        loss_G = loss_adv_G + loss_GTA
        loss_G.backward()
        self.BaseOpti.step()

        for param in self.net_D.parameters():
            param.requires_grad = True
        self.optimizer_D.zero_grad()
        source_D_out = self.net_D_DP(F.softmax(source_outputUp.detach(), dim=1))
        target_D_out = self.net_D_DP(F.softmax(target_outputUp.detach(), dim=1))
        loss_D = self.bceloss(source_D_out, torch.FloatTensor(source_D_out.data.size()).fill_(self.adv_source_label).to(source_D_out.device)) + \
                    self.bceloss(target_D_out, torch.FloatTensor(target_D_out.data.size()).fill_(self.adv_target_label).to(target_D_out.device))
        loss_D.backward()
        self.optimizer_D.step()

        return loss_GTA.item(), loss_adv_G.item(), loss_D.item()

    def step(self, source_x, source_label, target_x, target_imageS=None, target_params=None, target_lp=None, 
            target_lpsoft=None, target_image_full=None, target_weak_params=None):

        source_out = self.BaseNet_DP(source_x, ssl=True)
        source_outputUp = F.interpolate(source_out['out'], size=source_x.size()[2:], mode='bilinear', align_corners=True)

        loss_GTA = cross_entropy2d(input=source_outputUp, target=source_label)
        loss_GTA.backward()        

        if self.opt.proto_rectify:
            threshold_arg = F.interpolate(target_lpsoft, scale_factor=0.25, mode='bilinear', align_corners=True)
        else:
            threshold_arg = F.interpolate(target_lp.unsqueeze(1).float(), scale_factor=0.25).long()

        if self.opt.ema:
            ema_input = target_image_full
            with torch.no_grad():
                ema_out = self.BaseNet_ema_DP(ema_input)
            ema_out['feat'] = F.interpolate(ema_out['feat'], size=(int(ema_input.shape[2]/4), int(ema_input.shape[3]/4)), mode='bilinear', align_corners=True)
            ema_out['out'] = F.interpolate(ema_out['out'], size=(int(ema_input.shape[2]/4), int(ema_input.shape[3]/4)), mode='bilinear', align_corners=True)

        target_out = self.BaseNet_DP(target_imageS) if self.opt.S_pseudo > 0 else self.BaseNet_DP(target_x)
        target_out['out'] = F.interpolate(target_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
        target_out['feat'] = F.interpolate(target_out['feat'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)

        loss = torch.Tensor([0]).to(self.default_gpu)
        batch, _, w, h = threshold_arg.shape
        if self.opt.proto_rectify:
            weights = self.get_prototype_weight(ema_out['feat'], target_weak_params=target_weak_params)
            rectified = weights * threshold_arg
            threshold_arg = rectified.max(1, keepdim=True)[1]
            rectified = rectified / rectified.sum(1, keepdim=True)
            argmax = rectified.max(1, keepdim=True)[0]
            threshold_arg[argmax < self.opt.train_thred] = 250
        if self.opt.S_pseudo > 0:
            threshold_argS = self.label_strong_T(threshold_arg.clone().float(), target_params, padding=250, scale=4).to(torch.int64)
            threshold_arg = threshold_argS

        loss_CTS = cross_entropy2d(input=target_out['out'], target=threshold_arg.reshape([batch, w, h]))

        if self.opt.rce:
            rce = self.rce(target_out['out'], threshold_arg.reshape([batch, w, h]).clone())
            loss_CTS = self.opt.rce_alpha * loss_CTS + self.opt.rce_beta * rce

        if self.opt.regular_w > 0:
            regular_loss = self.regular_loss(target_out['out'])
            loss_CTS = loss_CTS + regular_loss * self.opt.regular_w

        loss_consist = torch.Tensor([0]).to(self.default_gpu)
        if self.opt.proto_consistW > 0:
            ema2weak_feat = self.full2weak(ema_out['feat'], target_weak_params)         #N*256*H*W
            ema2weak_feat_proto_distance = self.feat_prototype_distance(ema2weak_feat)  #N*19*H*W
            ema2strong_feat_proto_distance = self.label_strong_T(ema2weak_feat_proto_distance, target_params, padding=250, scale=4)
            mask = (ema2strong_feat_proto_distance != 250).float()
            teacher = F.softmax(-ema2strong_feat_proto_distance * self.opt.proto_temperature, dim=1)

            targetS_out = target_out if self.opt.S_pseudo > 0 else self.BaseNet_DP(target_imageS)
            targetS_out['out'] = F.interpolate(targetS_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
            targetS_out['feat'] = F.interpolate(targetS_out['feat'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)

            prototype_tmp = self.objective_vectors.expand(4, -1, -1)  #gpu memory limitation
            strong_feat_proto_distance = self.feat_prototype_distance_DP(targetS_out['feat'], prototype_tmp, self.class_numbers)
            student = F.log_softmax(-strong_feat_proto_distance * self.opt.proto_temperature, dim=1)

            loss_consist = F.kl_div(student, teacher, reduction='none')
            loss_consist = (loss_consist * mask).sum() / mask.sum()
            loss = loss + self.opt.proto_consistW * loss_consist

        loss = loss + loss_CTS
        loss.backward()
        self.BaseOpti.step()
        self.BaseOpti.zero_grad()

        if self.opt.moving_prototype: #update prototype
            ema_vectors, ema_ids = self.calculate_mean_vector(ema_out['feat'].detach(), ema_out['out'].detach())
            for t in range(len(ema_ids)):
                self.update_objective_SingleVector(ema_ids[t], ema_vectors[t].detach(), start_mean=False)
        
        if self.opt.ema: #update ema model
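            # EMA decay 0.999: param_k <- 0.999 * param_k + 0.001 * param_q;
            # BN buffers are copied from the student directly.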
            for param_q, param_k in zip(self.BaseNet.parameters(), self.BaseNet_ema.parameters()):
                param_k.data = param_k.data.clone() * 0.999 + param_q.data.clone() * (1. - 0.999)
            for buffer_q, buffer_k in zip(self.BaseNet.buffers(), self.BaseNet_ema.buffers()):
                buffer_k.data = buffer_q.data.clone()

        return loss.item(), loss_CTS.item(), loss_consist.item()

    def regular_loss(self, activation):
        logp = F.log_softmax(activation, dim=1)
        if self.opt.regular_type == 'MRENT':
            p = F.softmax(activation, dim=1)
            loss = (p * logp).sum() / (p.shape[0]*p.shape[2]*p.shape[3])
        elif self.opt.regular_type == 'MRKLD':
            loss = - logp.sum() / (logp.shape[0]*logp.shape[1]*logp.shape[2]*logp.shape[3])
        return loss

    def rce(self, pred, labels):
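        # reverse cross-entropy term (as in symmetric cross-entropy learning):
        # -sum_c p(c|x) * log q_label(c), with the one-hot labels clamped to avoid
        # log(0); pixels labelled 250 (ignore) are masked out of the average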
        pred = F.softmax(pred, dim=1)
        pred = torch.clamp(pred, min=1e-7, max=1.0)
        mask = (labels != 250).float()
        labels[labels==250] = self.class_numbers
        label_one_hot = torch.nn.functional.one_hot(labels, self.class_numbers + 1).float().to(self.default_gpu)
        label_one_hot = torch.clamp(label_one_hot.permute(0,3,1,2)[:,:-1,:,:], min=1e-4, max=1.0)
        rce = -(torch.sum(pred * torch.log(label_one_hot), dim=1) * mask).sum() / (mask.sum() + 1e-6)
        return rce

    def step_distillation(self, source_x, source_label, target_x, target_imageS=None, target_params=None, target_lp=None):

        source_out = self.BaseNet_DP(source_x, ssl=True)
        source_outputUp = F.interpolate(source_out['out'], size=source_x.size()[2:], mode='bilinear', align_corners=True)
        loss_GTA = cross_entropy2d(input=source_outputUp, target=source_label)
        loss_GTA.backward()

        threshold_arg = F.interpolate(target_lp.unsqueeze(1).float(), scale_factor=0.25).long()
        if self.opt.S_pseudo > 0:
            threshold_arg = self.label_strong_T(threshold_arg.clone().float(), target_params, padding=250, scale=4).to(torch.int64)
            target_out = self.BaseNet_DP(target_imageS)
        else:
            target_out = self.BaseNet_DP(target_x)
        target_out['out'] = F.interpolate(target_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
        batch, _, w, h = threshold_arg.shape
        loss = cross_entropy2d(input=target_out['out'], target=threshold_arg.reshape([batch, w, h]), size_average=True, reduction='mean')
        if self.opt.rce:
            rce = self.rce(target_out['out'], threshold_arg.reshape([batch, w, h]).clone())
            loss = self.opt.rce_alpha * loss + self.opt.rce_beta * rce

        if self.opt.distillation > 0:
            student = F.log_softmax(target_out['out'], dim=1)  # F.kl_div expects log-probabilities as input
            with torch.no_grad():
                teacher_out = self.teacher_DP(target_imageS)
                teacher_out['out'] = F.interpolate(teacher_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
                teacher = F.softmax(teacher_out['out'], dim=1)

            loss_kd = F.kl_div(student, teacher, reduction='none')
            mask = (teacher != 250).float()
            loss_kd = (loss_kd * mask).sum() / mask.sum()
            loss = loss + self.opt.distillation * loss_kd

        loss.backward()
        self.BaseOpti.step()
        self.BaseOpti.zero_grad()
        return loss_GTA.item(), loss.item()

    def full2weak(self, feat, target_weak_params):
        tmp = []
        for i in range(feat.shape[0]):
            h, w = target_weak_params['RandomSized'][0][i], target_weak_params['RandomSized'][1][i]
            feat_ = F.interpolate(feat[i:i+1], size=[int(h/4), int(w/4)], mode='bilinear', align_corners=True)
            y1, y2, x1, x2 = target_weak_params['RandomCrop'][0][i], target_weak_params['RandomCrop'][1][i], target_weak_params['RandomCrop'][2][i], target_weak_params['RandomCrop'][3][i]
            y1, th, x1, tw = int(y1/4), int((y2-y1)/4), int(x1/4), int((x2-x1)/4)
            feat_ = feat_[:, :, y1:y1+th, x1:x1+tw]
            if target_weak_params['RandomHorizontallyFlip'][i]:
                inv_idx = torch.arange(feat_.size(3)-1,-1,-1).long().to(feat_.device)
                feat_ = feat_.index_select(3,inv_idx)
            tmp.append(feat_)
        feat = torch.cat(tmp, 0)
        return feat

    def feat_prototype_distance(self, feat):
        N, C, H, W = feat.shape
        feat_proto_distance = -torch.ones((N, self.class_numbers, H, W)).to(feat.device)
        for i in range(self.class_numbers):
            #feat_proto_distance[:, i, :, :] = torch.norm(torch.Tensor(self.objective_vectors[i]).reshape(-1,1,1).expand(-1, H, W).to(feat.device) - feat, 2, dim=1,)
            feat_proto_distance[:, i, :, :] = torch.norm(self.objective_vectors[i].reshape(-1,1,1).expand(-1, H, W) - feat, 2, dim=1,)
        return feat_proto_distance

    def get_prototype_weight(self, feat, label=None, target_weak_params=None):
        feat = self.full2weak(feat, target_weak_params)
        feat_proto_distance = self.feat_prototype_distance(feat)
        feat_nearest_proto_distance, feat_nearest_proto = feat_proto_distance.min(dim=1, keepdim=True)

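        # subtract the per-pixel nearest-prototype distance so the minimum becomes 0,
        # then take a softmax over the negated, temperature-scaled distances to get
        # per-class prototype weights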
        feat_proto_distance = feat_proto_distance - feat_nearest_proto_distance
        weight = F.softmax(-feat_proto_distance * self.opt.proto_temperature, dim=1)
        return weight

    def label_strong_T(self, label, params, padding, scale=1):
        label = label + 1
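        # shift labels by +1 so that 0 can mark pixels removed by the strong
        # augmentation; those pixels are filled with `padding` before shifting back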
        for i in range(label.shape[0]):
            for (Tform, param) in params.items():
                if Tform == 'Hflip' and param[i].item() == 1:
                    label[i] = label[i].clone().flip(-1)
                elif (Tform == 'ShearX' or Tform == 'ShearY' or Tform == 'TranslateX' or Tform == 'TranslateY' or Tform == 'Rotate') and param[i].item() != 1e4:
                    v = int(param[i].item() // scale) if Tform == 'TranslateX' or Tform == 'TranslateY' else param[i].item()
                    label[i:i+1] = affine_sample(label[i:i+1].clone(), v, Tform)
                elif Tform == 'CutoutAbs' and isinstance(param, list):
                    x0 = int(param[0][i].item() // scale)
                    y0 = int(param[1][i].item() // scale)
                    x1 = int(param[2][i].item() // scale)
                    y1 = int(param[3][i].item() // scale)
                    label[i, :, y0:y1, x0:x1] = 0
        label[label == 0] = padding + 1  # for strong augmentation, constant padding
        label = label - 1
        return label

    def process_label(self, label):
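        # one-hot encode the labels into class_numbers + 1 channels; any value
        # >= class_numbers (e.g. the ignore index) falls into the extra last channel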
        batch, channel, w, h = label.size()
        pred1 = torch.zeros(batch, self.class_numbers + 1, w, h).to(self.default_gpu)
        id = torch.where(label < self.class_numbers, label, torch.Tensor([self.class_numbers]).to(self.default_gpu))
        pred1 = pred1.scatter_(1, id.long(), 1)
        return pred1

    def freeze_bn_apply(self):
        for net in self.nets:
            net.apply(freeze_bn)
        for net in self.nets_DP:
            net.apply(freeze_bn)

    def scheduler_step(self):
        for scheduler in self.schedulers:
            scheduler.step()
    
    def optimizer_zerograd(self):
        for optimizer in self.optimizers:
            optimizer.zero_grad()
    

    def init_device(self, net, gpu_id=None, whether_DP=False):
        gpu_id = gpu_id or self.default_gpu
        device = torch.device("cuda:{}".format(gpu_id) if torch.cuda.is_available() else 'cpu')
        net = net.to(device)
        # if torch.cuda.is_available():
        if whether_DP:
            #net = DataParallelWithCallback(net, device_ids=[0])
            net = DataParallelWithCallback(net, device_ids=range(torch.cuda.device_count()))
        return net
    
    def eval(self, net=None, logger=None):
        """Make specific models eval mode during test time"""
        # if issubclass(net, nn.Module) or issubclass(net, BaseModel):
        if net is None:
            for net in self.nets:
                net.eval()
            for net in self.nets_DP:
                net.eval()
            if logger is not None:
                logger.info("Successfully set the model eval mode")
        else:
            net.eval()
            if logger is not None:
                logger.info("Successfully set {} eval mode".format(net.__class__.__name__))
        return

    def train(self, net=None, logger=None):
        if net is None:
            for net in self.nets:
                net.train()
            for net in self.nets_DP:
                net.train()
        else:
            net.train()
        return

    def update_objective_SingleVector(self, id, vector, name='moving_average', start_mean=True):
        if vector.sum().item() == 0:
            return
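        # use a plain running mean for the first 100 vectors of a class, then
        # switch to a momentum-based exponential moving average (proto_momentum)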
        if start_mean and self.objective_vectors_num[id].item() < 100:
            name = 'mean'
        if name == 'moving_average':
            self.objective_vectors[id] = self.objective_vectors[id] * (1 - self.opt.proto_momentum) + self.opt.proto_momentum * vector.squeeze()
            self.objective_vectors_num[id] += 1
            self.objective_vectors_num[id] = min(self.objective_vectors_num[id], 3000)
        elif name == 'mean':
            self.objective_vectors[id] = self.objective_vectors[id] * self.objective_vectors_num[id] + vector.squeeze()
            self.objective_vectors_num[id] += 1
            self.objective_vectors[id] = self.objective_vectors[id] / self.objective_vectors_num[id]
            self.objective_vectors_num[id] = min(self.objective_vectors_num[id], 3000)
        else:
            raise NotImplementedError('no such updating way of objective vectors {}'.format(name))