def init_discriminator(self, args):
    """Construct the discriminator branch and its training criterion.

    Creates a 2-class FCDiscriminator on the GPU, a fixed 400x400 bilinear
    upsampler, and a segmentation loss built from ``args.loss_type``.

    :param args: namespace providing ``cuda`` and ``loss_type``.
    """
    disc = FCDiscriminator(num_classes=2)
    self.discriminator_model = disc.cuda()
    self.interp = nn.Upsample(size=400, mode='bilinear')
    loss_builder = SegmentationLosses(weight=None, cuda=args.cuda)
    self.disc_criterion = loss_builder.build_loss(mode=args.loss_type)
    return
def load_models(mode, device, args):
    """Build (and optionally restore from checkpoint) a network for one role.

    :param mode: "SS" (semantic segmentation) or "Discriminator"
    :param device: device the SS model is moved to
    :param args: namespace with network/shape options and optional
        ``checkpoint_SS`` / ``checkpoint_DNet`` paths
    :return: the constructed model
    :raises ValueError: if ``mode`` or ``args.network`` is unsupported
    """
    if mode == "SS":
        if args.network == "segnet_small":
            from models.SegNet import SegNet_Small
            model = SegNet_Small(args.channels, args.classes, args.skip_type,
                                 BR_bool=args.BR, separable_conv=args.SC)
            model = model.to(device)
            summary(model, (args.channels, args.image_size, args.image_size), args)
        else:
            # BUGFIX: previously fell through with `model` unbound, raising an
            # opaque UnboundLocalError below; fail fast with a clear message.
            raise ValueError("Invalid network {}!".format(args.network))
    elif mode == "Discriminator":
        from models.discriminator import FCDiscriminator
        model = FCDiscriminator(num_classes=NUM_CLASSES)
    else:
        raise ValueError("Invalid mode {}!".format(mode))
    try:
        # BUGFIX: load only the checkpoint that matches the requested role.
        # Previously, when both checkpoint paths were set, both state dicts
        # were loaded into the same model and the second silently won.
        if mode == "SS" and args.checkpoint_SS:
            model.load_state_dict(torch.load(args.checkpoint_SS))
        if mode == "Discriminator" and args.checkpoint_DNet:
            model.load_state_dict(torch.load(args.checkpoint_DNet))
    except Exception as e:
        print(e)
        sys.exit(1)  # BUGFIX: exit non-zero on failure (was 0 == success)
    return model
def __init__(self, cfg, writer, logger, use_pseudo_label=False, modal_num=3, multimodal_merger=multimodal_merger):
    """Build the full UDA training rig: teacher/student networks,
    per-branch discriminators, optimizers, schedulers and losses.

    :param cfg: nested config dict (reads the 'model' and 'training' sections)
    :param writer: summary writer, forwarded to setup/load helpers
    :param logger: logger used for progress messages
    :param use_pseudo_label: switches scheduler behaviour in scheduler_step()
    :param modal_num: number of modal branches; one extra (merged/attention)
        branch is always added, hence the recurring ``modal_num + 1``
    :param multimodal_merger: callable that merges per-branch outputs
    """
    self.cfg = cfg
    self.writer = writer
    self.class_numbers = 19  # 19-class (Cityscapes-style) label space
    self.logger = logger
    cfg_model = cfg['model']
    self.cfg_model = cfg_model
    self.best_iou = -100
    self.iter = 0
    self.nets = []
    self.split_gpu = 0
    self.default_gpu = cfg['model']['default_gpu']
    self.PredNet_Dir = None
    self.valid_classes = cfg['training']['valid_classes']
    self.G_train = True
    self.cls_feature_weight = cfg['training']['cls_feature_weight']
    self.use_pseudo_label = use_pseudo_label
    self.modal_num = modal_num

    # cluster vectors & cuda initialization
    # per-branch class prototypes: (modal_num+1) branches x 19 classes x 256 dims
    self.objective_vectors_group = torch.zeros(self.modal_num + 1, 19, 256).cuda()
    self.objective_vectors_num_group = torch.zeros(self.modal_num + 1, 19).cuda()
    self.objective_vectors_dis_group = torch.zeros(self.modal_num + 1, 19, 19).cuda()
    self.class_threshold_group = torch.full([self.modal_num + 1, 19], 0.6).cuda()
    # discriminator-probability threshold separating easy from hard pixels
    self.disc_T = torch.FloatTensor([0.0]).cuda()
    #self.metrics = CustomMetrics(self.class_numbers)
    self.metrics = CustomMetrics(self.class_numbers, modal_num=self.modal_num, model=self)

    # multimodal / multi-branch merger
    self.multimodal_merger = multimodal_merger

    # NOTE(review): the BatchNorm class selected here is never used below —
    # the networks are constructed with the raw cfg string instead.
    bn = cfg_model['bn']
    if bn == 'sync_bn':
        BatchNorm = SynchronizedBatchNorm2d
    elif bn == 'bn':
        BatchNorm = nn.BatchNorm2d
    elif bn == 'gn':
        BatchNorm = nn.GroupNorm
    else:
        raise NotImplementedError('batch norm choice {} is not implemented'.format(bn))
    if True:
        # frozen teacher segmentation network producing pseudo labels
        self.PredNet = DeepLab(
            num_classes=19,
            backbone=cfg_model['basenet']['version'],
            output_stride=16,
            bn=cfg_model['bn'],
            freeze_bn=True,
            modal_num=self.modal_num
        ).cuda()
        self.load_PredNet(cfg, writer, logger, dir=None, net=self.PredNet)
        self.PredNet_DP = self.init_device(self.PredNet, gpu_id=self.default_gpu, whether_DP=True)
        self.PredNet.eval()
        self.PredNet_num = 0
        # frozen teacher discriminator used for easy/hard pixel weighting
        self.PredDnet = FCDiscriminator(inplanes=19)
        self.load_PredDnet(cfg, writer, logger, dir=None, net=self.PredDnet)
        self.PredDnet_DP = self.init_device(self.PredDnet, gpu_id=self.default_gpu, whether_DP=True)
        self.PredDnet.eval()

    # trainable student segmentation network
    self.BaseNet = DeepLab(
        num_classes=19,
        backbone=cfg_model['basenet']['version'],
        output_stride=16,
        bn=cfg_model['bn'],
        freeze_bn=True,
        modal_num=self.modal_num
    )
    logger.info('the backbone is {}'.format(cfg_model['basenet']['version']))
    self.BaseNet_DP = self.init_device(self.BaseNet, gpu_id=self.default_gpu, whether_DP=True)
    self.nets.extend([self.BaseNet])
    self.nets_DP = [self.BaseNet_DP]

    # Discriminator
    self.SOURCE_LABEL = 0
    self.TARGET_LABEL = 1
    self.DNets = []
    self.DNets_DP = []
    # one adversarial discriminator per modal branch + one for the merged branch
    for _ in range(self.modal_num+1):
        _net_d = FCDiscriminator(inplanes=19)
        self.DNets.append(_net_d)
        _net_d_DP = self.init_device(_net_d, gpu_id=self.default_gpu, whether_DP=True)
        self.DNets_DP.append(_net_d_DP)
    self.nets.extend(self.DNets)
    self.nets_DP.extend(self.DNets_DP)

    self.optimizers = []
    self.schedulers = []
    # SGD for the generator, Adam for the discriminators
    optimizer_cls = torch.optim.SGD
    optimizer_params = {k:v for k, v in cfg['training']['optimizer'].items() if k != 'name'}
    optimizer_cls_D = torch.optim.Adam
    optimizer_params_D = {k:v for k, v in cfg['training']['optimizer_D'].items() if k != 'name'}
    if False:
        self.BaseOpti = optimizer_cls(self.BaseNet.parameters(), **optimizer_params)
    else:
        # model supplies its own per-group LRs (backbone vs. head)
        self.BaseOpti = optimizer_cls(self.BaseNet.optim_parameters(cfg['training']['optimizer']['lr']), **optimizer_params)
    self.optimizers.extend([self.BaseOpti])

    self.DiscOptis = []
    for _d_net in self.DNets:
        self.DiscOptis.append(
            optimizer_cls_D(_d_net.parameters(), **optimizer_params_D)
        )
    self.optimizers.extend(self.DiscOptis)

    self.schedulers = []
    if False:
        self.BaseSchedule = get_scheduler(self.BaseOpti, cfg['training']['lr_schedule'])
        self.schedulers.extend([self.BaseSchedule])
    else:
        """BaseSchedule detail see FUNC: scheduler_step()"""
        # generator LR is decayed manually via lr_poly(); this scheduler object
        # is kept only so optimizer/scheduler lists stay index-aligned
        self.learning_rate = cfg['training']['optimizer']['lr']
        self.gamma = cfg['training']['lr_schedule']['gamma']
        self.num_steps = cfg['training']['lr_schedule']['max_iter']
        self._BaseSchedule_nouse = get_scheduler(self.BaseOpti, cfg['training']['lr_schedule'])
        self.schedulers.extend([self._BaseSchedule_nouse])

    self.DiscSchedules = []
    for _disc_opt in self.DiscOptis:
        self.DiscSchedules.append(
            get_scheduler(_disc_opt, cfg['training']['lr_schedule'])
        )
    self.schedulers.extend(self.DiscSchedules)

    self.setup(cfg, writer, logger)

    self.adv_source_label = 0
    self.adv_target_label = 1
    # per-pixel BCE; reduce=False keeps the spatial map so it can be masked
    self.bceloss = nn.BCEWithLogitsLoss(reduce=False)
    self.loss_fn = get_loss_function(cfg)
    # pseudo-label loss uses a 4D cross-entropy variant of the configured loss
    pseudo_cfg = copy.deepcopy(cfg)
    pseudo_cfg['training']['loss']['name'] = 'cross_entropy4d'
    self.pseudo_loss_fn = get_loss_function(pseudo_cfg)
    self.mseloss = nn.MSELoss()
    self.l1loss = nn.L1Loss()
    self.smoothloss = nn.SmoothL1Loss()
    self.triplet_loss = nn.TripletMarginLoss()
    self.kl_distance = nn.KLDivLoss(reduction='none')
class CustomModel():
    """Multimodal unsupervised-domain-adaptation trainer.

    Holds a frozen teacher (PredNet + PredDnet), a trainable student
    (BaseNet), and one adversarial discriminator per modal branch plus one
    for the merged branch. ``step()`` performs a full generator +
    discriminator update.
    """

    def __init__(self, cfg, writer, logger, use_pseudo_label=False, modal_num=3, multimodal_merger=multimodal_merger):
        """Build networks, discriminators, optimizers, schedulers and losses.

        :param cfg: nested config dict (reads 'model' and 'training' sections)
        :param writer: summary writer, forwarded to setup/load helpers
        :param logger: logger used for progress messages
        :param use_pseudo_label: switches scheduler behaviour in scheduler_step()
        :param modal_num: number of modal branches; one extra (merged) branch
            is always added, hence the recurring ``modal_num + 1``
        :param multimodal_merger: callable that merges per-branch outputs
        """
        self.cfg = cfg
        self.writer = writer
        self.class_numbers = 19  # 19-class (Cityscapes-style) label space
        self.logger = logger
        cfg_model = cfg['model']
        self.cfg_model = cfg_model
        self.best_iou = -100
        self.iter = 0
        self.nets = []
        self.split_gpu = 0
        self.default_gpu = cfg['model']['default_gpu']
        self.PredNet_Dir = None
        self.valid_classes = cfg['training']['valid_classes']
        self.G_train = True
        self.cls_feature_weight = cfg['training']['cls_feature_weight']
        self.use_pseudo_label = use_pseudo_label
        self.modal_num = modal_num

        # cluster vectors & cuda initialization
        # per-branch class prototypes: (modal_num+1) x 19 classes x 256 dims
        self.objective_vectors_group = torch.zeros(self.modal_num + 1, 19, 256).cuda()
        self.objective_vectors_num_group = torch.zeros(self.modal_num + 1, 19).cuda()
        self.objective_vectors_dis_group = torch.zeros(self.modal_num + 1, 19, 19).cuda()
        self.class_threshold_group = torch.full([self.modal_num + 1, 19], 0.6).cuda()
        # discriminator-probability threshold separating easy from hard pixels
        self.disc_T = torch.FloatTensor([0.0]).cuda()
        #self.metrics = CustomMetrics(self.class_numbers)
        self.metrics = CustomMetrics(self.class_numbers, modal_num=self.modal_num, model=self)

        # multimodal / multi-branch merger
        self.multimodal_merger = multimodal_merger

        # NOTE(review): the BatchNorm class selected here is never used below —
        # the networks are constructed with the raw cfg string instead.
        bn = cfg_model['bn']
        if bn == 'sync_bn':
            BatchNorm = SynchronizedBatchNorm2d
        elif bn == 'bn':
            BatchNorm = nn.BatchNorm2d
        elif bn == 'gn':
            BatchNorm = nn.GroupNorm
        else:
            raise NotImplementedError('batch norm choice {} is not implemented'.format(bn))
        if True:
            # frozen teacher segmentation network producing pseudo labels
            self.PredNet = DeepLab(
                num_classes=19,
                backbone=cfg_model['basenet']['version'],
                output_stride=16,
                bn=cfg_model['bn'],
                freeze_bn=True,
                modal_num=self.modal_num
            ).cuda()
            self.load_PredNet(cfg, writer, logger, dir=None, net=self.PredNet)
            self.PredNet_DP = self.init_device(self.PredNet, gpu_id=self.default_gpu, whether_DP=True)
            self.PredNet.eval()
            self.PredNet_num = 0
            # frozen teacher discriminator used for easy/hard pixel weighting
            self.PredDnet = FCDiscriminator(inplanes=19)
            self.load_PredDnet(cfg, writer, logger, dir=None, net=self.PredDnet)
            self.PredDnet_DP = self.init_device(self.PredDnet, gpu_id=self.default_gpu, whether_DP=True)
            self.PredDnet.eval()

        # trainable student segmentation network
        self.BaseNet = DeepLab(
            num_classes=19,
            backbone=cfg_model['basenet']['version'],
            output_stride=16,
            bn=cfg_model['bn'],
            freeze_bn=True,
            modal_num=self.modal_num
        )
        logger.info('the backbone is {}'.format(cfg_model['basenet']['version']))
        self.BaseNet_DP = self.init_device(self.BaseNet, gpu_id=self.default_gpu, whether_DP=True)
        self.nets.extend([self.BaseNet])
        self.nets_DP = [self.BaseNet_DP]

        # Discriminator
        self.SOURCE_LABEL = 0
        self.TARGET_LABEL = 1
        self.DNets = []
        self.DNets_DP = []
        # one adversarial discriminator per modal branch + one merged branch
        for _ in range(self.modal_num+1):
            _net_d = FCDiscriminator(inplanes=19)
            self.DNets.append(_net_d)
            _net_d_DP = self.init_device(_net_d, gpu_id=self.default_gpu, whether_DP=True)
            self.DNets_DP.append(_net_d_DP)
        self.nets.extend(self.DNets)
        self.nets_DP.extend(self.DNets_DP)

        self.optimizers = []
        self.schedulers = []
        # SGD for the generator, Adam for the discriminators
        optimizer_cls = torch.optim.SGD
        optimizer_params = {k:v for k, v in cfg['training']['optimizer'].items() if k != 'name'}
        optimizer_cls_D = torch.optim.Adam
        optimizer_params_D = {k:v for k, v in cfg['training']['optimizer_D'].items() if k != 'name'}
        if False:
            self.BaseOpti = optimizer_cls(self.BaseNet.parameters(), **optimizer_params)
        else:
            # model supplies its own per-group LRs (backbone vs. head)
            self.BaseOpti = optimizer_cls(self.BaseNet.optim_parameters(cfg['training']['optimizer']['lr']), **optimizer_params)
        self.optimizers.extend([self.BaseOpti])

        self.DiscOptis = []
        for _d_net in self.DNets:
            self.DiscOptis.append(
                optimizer_cls_D(_d_net.parameters(), **optimizer_params_D)
            )
        self.optimizers.extend(self.DiscOptis)

        self.schedulers = []
        if False:
            self.BaseSchedule = get_scheduler(self.BaseOpti, cfg['training']['lr_schedule'])
            self.schedulers.extend([self.BaseSchedule])
        else:
            """BaseSchedule detail see FUNC: scheduler_step()"""
            # generator LR is decayed manually via lr_poly(); this scheduler
            # is kept only so optimizer/scheduler lists stay index-aligned
            self.learning_rate = cfg['training']['optimizer']['lr']
            self.gamma = cfg['training']['lr_schedule']['gamma']
            self.num_steps = cfg['training']['lr_schedule']['max_iter']
            self._BaseSchedule_nouse = get_scheduler(self.BaseOpti, cfg['training']['lr_schedule'])
            self.schedulers.extend([self._BaseSchedule_nouse])

        self.DiscSchedules = []
        for _disc_opt in self.DiscOptis:
            self.DiscSchedules.append(
                get_scheduler(_disc_opt, cfg['training']['lr_schedule'])
            )
        self.schedulers.extend(self.DiscSchedules)

        self.setup(cfg, writer, logger)

        self.adv_source_label = 0
        self.adv_target_label = 1
        # per-pixel BCE; reduce=False keeps the spatial map so it can be masked
        self.bceloss = nn.BCEWithLogitsLoss(reduce=False)
        self.loss_fn = get_loss_function(cfg)
        # pseudo-label loss uses a 4D cross-entropy variant
        pseudo_cfg = copy.deepcopy(cfg)
        pseudo_cfg['training']['loss']['name'] = 'cross_entropy4d'
        self.pseudo_loss_fn = get_loss_function(pseudo_cfg)
        self.mseloss = nn.MSELoss()
        self.l1loss = nn.L1Loss()
        self.smoothloss = nn.SmoothL1Loss()
        self.triplet_loss = nn.TripletMarginLoss()
        self.kl_distance = nn.KLDivLoss(reduction='none')

    def create_PredNet(self,):
        """Build a fresh DeepLab (eval mode, on GPU) with the student's config."""
        ss = DeepLab(
            num_classes=19,
            backbone=self.cfg_model['basenet']['version'],
            output_stride=16,
            bn=self.cfg_model['bn'],
            freeze_bn=True,
            modal_num=self.modal_num,
        ).cuda()
        ss.eval()
        return ss

    def setup(self, cfg, writer, logger):
        '''
        set optimizer and load pretrained model
        '''
        for net in self.nets:
            # name = net.__class__.__name__
            self.init_weights(cfg['model']['init'], logger, net)
            print("Initializition completed")
            if hasattr(net, '_load_pretrained_model') and cfg['model']['pretrained']:
                print("loading pretrained model for {}".format(net.__class__.__name__))
                net._load_pretrained_model()
        '''load pretrained model
        '''
        if cfg['training']['resume_flag']:
            self.load_nets(cfg, writer, logger)
        pass

    def lr_poly(self):
        """Polynomially decayed LR for the current iteration (see scheduler_step)."""
        return self.learning_rate * ((1 - float(self.iter) / self.num_steps) ** (self.gamma))

    def adjust_basenet_learning_rate(self):
        """Apply lr_poly() to the generator optimizer; group 1 (head) gets 10x."""
        lr = self.lr_poly()
        self.BaseOpti.param_groups[0]['lr'] = lr
        if len(self.BaseOpti.param_groups) > 1:
            self.BaseOpti.param_groups[1]['lr'] = lr * 10

    def forward(self, input):
        """Run the student; drops the attention mask from the 5-tuple output."""
        feat, feat_low, att_mask, feat_cls, output = self.BaseNet_DP(input)
        return feat, feat_low, feat_cls, output

    def forward_Up(self, input):
        """Forward pass followed by merging/upsampling to the input resolution."""
        feat, feat_low, feat_cls, outputs = self.forward(input)
        merge_out = self.multimodal_merger(
            {
                'feat_cls': feat_cls,
                # NOTE(review): `output` is undefined in this scope — the local
                # returned by self.forward() above is named `outputs`. This
                # would raise NameError if called; likely a typo.
                'output': output,
            },
            is_upsample=True, size=input.size()[2:],
        )
        return feat, feat_low, merge_out['feat_cls'], merge_out['output_comb']

    def PredNet_Forward(self, input):
        """Run the frozen teacher under no_grad; returns its full 5-tuple."""
        with torch.no_grad():
            _, _, att_mask, feat_cls, output_result = self.PredNet_DP(input)
            return _, _, att_mask, feat_cls, output_result

    def calculate_mean_vector(self, feat_cls, outputs, labels, ):
        """Compute per-class mean feature vectors where prediction agrees with labels.

        Classes covering fewer than 10 agreeing pixels in a sample are skipped.
        Returns (vectors, ids): one pooled 256-d vector per retained class id.
        """
        outputs_softmax = F.softmax(outputs, dim=1)
        outputs_argmax = outputs_softmax.argmax(dim=1, keepdim=True)
        outputs_argmax = self.process_label(outputs_argmax.float())
        labels_expanded = self.process_label(labels)
        # keep only pixels where the prediction matches the (pseudo) label
        outputs_pred = labels_expanded * outputs_argmax
        scale_factor = F.adaptive_avg_pool2d(outputs_pred, 1)
        vectors = []
        ids = []
        for n in range(feat_cls.size()[0]):
            for t in range(self.class_numbers):
                if scale_factor[n][t].item()==0:
                    continue
                if (outputs_pred[n][t] > 0).sum() < 10:
                    continue
                s = feat_cls[n] * outputs_pred[n][t]
                scale = torch.sum(outputs_pred[n][t]) / labels.shape[2] / labels.shape[3] * 2
                s = normalisation_pooling()(s, scale)
                # average over the masked region only
                s = F.adaptive_avg_pool2d(s, 1) / scale_factor[n][t]
                vectors.append(s)
                ids.append(t)
        return vectors, ids

    def step(self, source_x, source_label, source_modal_ids, target_x, target_label, target_modal_ids, use_pseudo_loss=False):
        """One full training step: generator losses (source seg + pseudo-label
        + prototype alignment + adversarial) then per-branch discriminators.

        Performs backward + optimizer steps internally.
        Returns (loss, loss_adv, loss_D_comb) as tensors.
        """
        assert len(source_modal_ids) == source_x.size(0), "modal_ids' batchsize != source_x's batchsize"
        _, _, source_feat_cls, source_output = self.forward(input=source_x)
        """source_output: [B x 19 x W x H, ...]
        select modal-branch output in each batchsize
        Specific-modal output
        """
        # pick, per sample, the output of the branch matching its modality
        source_output_modal_k = torch.stack(
            [
                source_output[_modal_i][_batch_i]
                for _batch_i, _modal_i in enumerate(source_modal_ids)
            ],
            dim=0,
        )
        # attention output & specific-modal output
        source_output_comb = torch.cat([source_output_modal_k, source_output[-1]], dim=0)
        source_label_comb = torch.cat([source_label, source_label.clone()], dim=0)
        source_outputUp = F.interpolate(source_output_comb, size=source_x.size()[-2:], mode='bilinear', align_corners=True)
        # supervised segmentation loss on the source domain
        loss_GTA = self.loss_fn(input=source_outputUp, target=source_label_comb)

        self.PredNet.eval()
        with torch.no_grad():
            _, _, att_mask, feat_cls, output = self.PredNet_Forward(target_x)
            threshold_args_comb, cluster_args_comb = self.metrics.update(feat_cls, output, target_label, modal_ids=[_i for _i in range(self.modal_num+1)], att_mask=att_mask)

        """ Discriminator-guided easy/hard training """
        target_label_size = target_label.size()
        t_out = output[-1]
        _t_out = F.interpolate(t_out.detach(), size=(target_label_size[1]*4, target_label_size[2]*4), mode='bilinear', align_corners=True)
        _t_D_out = self.PredDnet_DP(F.softmax(_t_out))
        _t_D_out_prob = F.sigmoid(_t_D_out)
        # easy pixels: confident for the teacher D and not ignore-labelled (250)
        disc_easy_weight = torch.where(_t_D_out_prob > self.disc_T, _t_D_out_prob, torch.FloatTensor([0.0]).cuda())
        disc_easy_weight = torch.where(threshold_args_comb != 250, disc_easy_weight, torch.FloatTensor([0.0]).cuda()).squeeze(1)
        # hard pixels: low teacher-D confidence or ignore label
        disc_hard_mask = torch.where(_t_D_out_prob < self.disc_T, torch.Tensor([1]).cuda(), torch.Tensor([0]).cuda())
        disc_hard_mask = torch.where(threshold_args_comb == 250, torch.Tensor([1]).cuda(), disc_hard_mask)

        loss_L2_source_cls = torch.Tensor([0]).cuda(self.split_gpu)
        loss_L2_target_cls = torch.Tensor([0]).cuda(self.split_gpu)
        _, _, target_feat_cls, target_output = self.forward(target_x)
        if self.cfg['training']['loss_L2_cls']:
            # distance loss
            _batch, _w, _h = source_label.shape
            source_label_downsampled = source_label.reshape([_batch,1,_w, _h]).float()
            source_label_downsampled = F.interpolate(source_label_downsampled.float(), size=source_feat_cls[0].size()[-2:], mode='nearest')
            #or F.softmax(input=source_output, dim=1)
            loss_L2_source_cls = torch.Tensor([0]).cuda()
            loss_L2_target_cls = torch.Tensor([0]).cuda()
            # prototype alignment per branch (first two branches are skipped)
            for _modal_i, _source_feat_i, _source_out_i, _target_feat_i, _target_out_i in zip(range(self.modal_num + 1), source_feat_cls, source_output, target_feat_cls, target_output):
                if _modal_i < 2:
                    continue
                source_vectors, source_ids = self.calculate_mean_vector(_source_feat_i, _source_out_i, source_label_downsampled)
                loss_L2_source_cls += self.class_vectors_alignment(source_ids, source_vectors, modal_ids=[_modal_i,])
                target_vectors, target_ids = self.calculate_mean_vector(_target_feat_i, _target_out_i, cluster_args_comb.float())
                loss_L2_target_cls += self.class_vectors_alignment(target_ids, target_vectors, modal_ids=[_modal_i,])
        loss_L2_cls = self.cls_feature_weight * (loss_L2_source_cls + loss_L2_target_cls)
        # ad-hoc damping of a too-large alignment loss
        if loss_L2_cls.item() > 1.0:
            loss_L2_cls = loss_L2_cls / 10.0
        if loss_L2_cls.item() > 0.5:
            loss_L2_cls = loss_L2_cls / 3.0

        target_label_size = target_label.size()
        loss = torch.Tensor([0]).cuda()
        batch, _, w, h = threshold_args_comb.shape
        _cluster_args_comb = cluster_args_comb.reshape([batch, w, h])
        _threshold_args_comb = threshold_args_comb.reshape([batch, w, h])
        _target_output = target_output[-1]
        _loss_CTS = self.pseudo_loss_fn(input=_target_output, target=_threshold_args_comb)
        # CAG-based and probability-based PLA
        # pseudo-label loss averaged over easy pixels only
        loss_CTS = torch.sum(_loss_CTS * disc_easy_weight) / (1 + (disc_easy_weight > 0).sum())

        if self.G_train and self.cfg['training']['loss_pseudo_label']:
            loss = loss + loss_CTS
        if self.G_train and self.cfg['training']['loss_source_seg']:
            loss = loss + loss_GTA
        if self.cfg['training']['loss_L2_cls']:
            loss = loss + torch.sum(loss_L2_cls)

        # adversarial loss
        # -----------------------------
        """Generator (segmentation)"""
        # -----------------------------
        # On Source Domain
        loss_adv = torch.Tensor([0]).cuda()
        _batch_size = 0
        source_modal_ids_tensor = torch.Tensor(source_modal_ids).cuda()
        target_modal_ids_tensor = torch.Tensor(target_modal_ids).cuda()
        for t_out, _d_net_DP, _d_net, modal_idx in zip(target_output, self.DNets_DP, self.DNets, range(len(target_output))):
            # set grad false
            self.set_requires_grad(self.logger, _d_net, requires_grad = False)
            t_D_out = _d_net_DP(F.softmax(t_out))
            _disc_hard_mask = F.interpolate(disc_hard_mask, size=(t_D_out.size(2), t_D_out.size(3)), mode='nearest')
            #source_modal_ids
            # fool the discriminator on hard target pixels (label 1.0 = source)
            loss_temp = torch.sum(self.bceloss(
                t_D_out,
                torch.FloatTensor(t_D_out.data.size()).fill_(1.0).cuda()
            ) * _disc_hard_mask, [1,2,3]) / (torch.sum(disc_hard_mask, [1,2,3]) + 1)
            if modal_idx >= self.modal_num:
                # merged branch: every sample contributes
                loss_adv += torch.mean(loss_temp)
            elif torch.mean(torch.as_tensor((modal_idx==target_modal_ids_tensor), dtype=torch.float32)) == 0:
                # no sample of this modality in the batch
                loss_adv += 0.0
            else:
                loss_adv += torch.mean(torch.masked_select(loss_temp, target_modal_ids_tensor==modal_idx))
            _batch_size += t_out.size(0)
        loss_adv *= self.cfg['training']['loss_adv_lambda']

        loss_G = torch.Tensor([0]).cuda()
        loss_G = loss_G + loss_adv
        loss = loss + loss_G
        if loss.item() != 0:
            loss.backward()
            self.BaseOpti.step()
            self.BaseOpti.zero_grad()

        # -----------------------------
        """Discriminator """
        # -----------------------------
        _batch_size = 0
        loss_D_comb = torch.Tensor([0]).cuda()
        source_label_size = source_label.size()
        for s_out, t_out, _d_net_DP, _d_net, _disc_opt, modal_idx in zip(source_output, target_output, self.DNets_DP, self.DNets, self.DiscOptis, range(len(source_output))):
            self.set_requires_grad(self.logger, _d_net, requires_grad = True)
            _batch_size = 0
            loss_D = torch.Tensor([0]).cuda()
            # source domain
            s_D_out = _d_net_DP(F.softmax(s_out.detach()))
            loss_temp_s = torch.mean(self.bceloss(
                s_D_out,
                torch.FloatTensor(s_D_out.data.size()).fill_(1.0).cuda()
            ), [1,2,3])
            if modal_idx >= self.modal_num:
                loss_D += torch.mean(loss_temp_s)
            elif torch.mean(torch.as_tensor((modal_idx==source_modal_ids_tensor), dtype=torch.float32)) == 0:
                loss_D += 0.0
            else:
                loss_D += torch.mean(torch.masked_select(loss_temp_s, source_modal_ids_tensor==modal_idx))
            # target domain
            _batch_size += (s_out.size(0) + t_out.size(0))
            t_D_out = _d_net_DP(F.softmax(t_out.detach()))
            loss_temp_t = torch.mean(self.bceloss(
                t_D_out,
                torch.FloatTensor(t_D_out.data.size()).fill_(0.0).cuda()
            ), [1,2,3])
            if modal_idx >= self.modal_num:
                loss_D += torch.mean(loss_temp_t)
            elif torch.mean(torch.as_tensor((modal_idx==target_modal_ids_tensor), dtype=torch.float32)) == 0:
                loss_D += 0.0
            else:
                loss_D += torch.mean(torch.masked_select(loss_temp_t, target_modal_ids_tensor==modal_idx))
            loss_D *= self.cfg['training']['loss_adv_lambda']*0.5
            if loss_D.item() != 0:
                loss_D.backward()
                _disc_opt.step()
                _disc_opt.zero_grad()
            loss_D_comb += loss_D
        return loss, loss_adv, loss_D_comb

    def process_label(self, label):
        """One-hot encode a (B,1,W,H) label map into (B,20,W,H); channel 19
        collects every label >= 19 (ignore/overflow bucket)."""
        batch, channel, w, h = label.size()
        pred1 = torch.zeros(batch, 20, w, h).cuda()
        id = torch.where(label < 19, label, torch.Tensor([19]).cuda())
        pred1 = pred1.scatter_(1, id.long(), 1)
        return pred1

    def class_vectors_alignment(self, ids, vectors, modal_ids=[0,]):
        # NOTE(review): mutable default argument; harmless here since it is
        # never mutated, but worth cleaning up.
        """SmoothL1 distance between computed class vectors and the stored
        per-class objective (prototype) vectors, averaged and scaled by 10."""
        loss = torch.Tensor([0]).cuda()
        """construct category objective vectors"""
        # objective_vectors_group 2 x 19 x 256 --> 19 x 512
        _objective_vectors_set = self.metrics.multimodal_merger.merge_objective_vectors(modal_ids=modal_ids)
        for i in range(len(ids)):
            if ids[i] not in self.valid_classes:
                continue
            new_loss = self.smoothloss(vectors[i].squeeze().cuda(), _objective_vectors_set[ids[i]])
            # repeatedly damp outlier distances
            while (new_loss.item() > 5):
                new_loss = new_loss / 10
            loss = loss + new_loss
        loss = loss / len(ids) * 10
        return loss

    def freeze_bn_apply(self):
        """Apply the freeze_bn hook to every net (plain and DataParallel)."""
        for net in self.nets:
            net.apply(freeze_bn)
        for net in self.nets_DP:
            net.apply(freeze_bn)

    def scheduler_step(self):
        """Advance LR schedulers; without pseudo labels the generator LR is
        decayed manually (its placeholder scheduler at index 0 is skipped)."""
        if self.use_pseudo_label:
            for scheduler in self.schedulers:
                scheduler.step()
        else:
            """skipped _BaseScheduler_nouse"""
            for scheduler in self.schedulers[1:]:
                scheduler.step()
            self.adjust_basenet_learning_rate()

    def optimizer_zerograd(self):
        """Zero gradients on every registered optimizer."""
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def optimizer_step(self):
        """Step every registered optimizer."""
        for opt in self.optimizers:
            opt.step()

    def init_device(self, net, gpu_id=None, whether_DP=False):
        """Move a net to the (default) GPU, optionally wrapping in DataParallel.

        :param gpu_id: target GPU index; falls back to self.default_gpu
        :param whether_DP: wrap with DataParallelWithCallback over all GPUs
        """
        gpu_id = gpu_id or self.default_gpu
        device = torch.device("cuda:{}".format(gpu_id) if torch.cuda.is_available() else 'cpu')
        net = net.to(device)
        if whether_DP:
            net = DataParallelWithCallback(net, device_ids=range(torch.cuda.device_count()))
        return net

    def eval(self, net=None, logger=None):
        """Make specific models eval mode during test time"""
        if net == None:
            for net in self.nets:
                net.eval()
            for net in self.nets_DP:
                net.eval()
            if logger!=None:
                logger.info("Successfully set the model eval mode")
        else:
            net.eval()
            if logger!=None:
                # NOTE(review): called as logger(...) here but logger.info(...)
                # above — confirm the expected logger interface.
                logger("Successfully set {} eval mode".format(net.__class__.__name__))
        return

    def train(self, net=None, logger=None):
        """Set all nets (or one given net) to training mode."""
        if net==None:
            for net in self.nets:
                net.train()
            for net in self.nets_DP:
                net.train()
        else:
            net.train()
        return

    def set_requires_grad(self, logger, net, requires_grad = False):
        """Set requires_grad=Fasle for all the networks to avoid unnecessary computations
        Parameters:
            net (BaseModel)       -- the network which will be operated on
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        for parameter in net.parameters():
            parameter.requires_grad = requires_grad

    def set_requires_grad_layer(self, logger, net, layer_type='batchnorm', requires_grad=False):
        '''
        set specific type of layers whether needing grad
        '''
        # NOTE(review): the `net` parameter is shadowed — this iterates over
        # ALL of self.nets, not the net that was passed in.
        # print('Warning: all the BatchNorm params are fixed!')
        # logger.info('Warning: all the BatchNorm params are fixed!')
        for net in self.nets:
            for _i in net.modules():
                if _i.__class__.__name__.lower().find(layer_type.lower()) != -1:
                    _i.weight.requires_grad = requires_grad
        return

    def init_weights(self, cfg, logger, net, init_type='normal', init_gain=0.02):
        """Initialize network weights.

        Parameters:
            net (network)     -- network to be initialized
            init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
            init_gain (float) -- scaling factor for normal, xavier and orthogonal.

        We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
        work better for some applications. Feel free to try yourself.
        """
        init_type = cfg.get('init_type', init_type)
        init_gain = cfg.get('init_gain', init_gain)

        def init_func(m):
            # define the initialization function
            classname = m.__class__.__name__
            if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, init_gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=init_gain)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=init_gain)
                else:
                    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif isinstance(m, SynchronizedBatchNorm2d) or classname.find('BatchNorm2d') != -1 \
                    or isinstance(m, nn.GroupNorm):
                # or isinstance(m, InPlaceABN) or isinstance(m, InPlaceABNSync):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
                # BatchNorm Layer's weight is not a matrix; only normal distribution applies.

        print('initialize {} with {}'.format(init_type, net.__class__.__name__))
        logger.info('initialize {} with {}'.format(init_type, net.__class__.__name__))
        net.apply(init_func)  # apply the initialization function <init_func>
        pass

    def adaptive_load_nets(self, net, model_weight):
        """Partially load a state dict: only keys that exist in the target net."""
        model_dict = net.state_dict()
        pretrained_dict = {k : v for k, v in model_weight.items() if k in model_dict}
        print("[INFO] Pretrained dict:", pretrained_dict.keys())
        model_dict.update(pretrained_dict)
        net.load_state_dict(model_dict)

    def load_nets(self, cfg, writer, logger):
        # load pretrained weights on the net
        """Restore all nets (and optionally optimizers/schedulers) from
        cfg['training']['resume']; handles both the list-style (per-instance)
        and legacy single-dict checkpoint layouts."""
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg['training']['resume'])
            )
            checkpoint = torch.load(cfg['training']['resume'])
            _k = -1
            # counts instances per class name so repeated nets (e.g. the
            # FCDiscriminators) pick the right entry from a checkpoint list
            net_state_no = {}
            for net in self.nets:
                name = net.__class__.__name__
                if name not in net_state_no:
                    net_state_no[name] = 0
                else:
                    net_state_no[name] += 1
                _k += 1
                if checkpoint.get(name) == None:
                    continue
                if name.find('FCDiscriminator') != -1 and cfg['training']['gan_resume'] == False:
                    continue
                #self.adaptive_load_nets(net, checkpoint[name]["model_state"])
                if isinstance(checkpoint[name], list):
                    self.adaptive_load_nets(net, checkpoint[name][net_state_no[name]]["model_state"])
                else:
                    print("*****************************************")
                    print("[WARNING] Using depreciated load version! Model {}".format(name))
                    print("*****************************************")
                    self.adaptive_load_nets(net, checkpoint[name]["model_state"])
                if cfg['training']['optimizer_resume']:
                    if isinstance(checkpoint[name], list):
                        self.adaptive_load_nets(self.optimizers[_k], checkpoint[name][net_state_no[name]]["optimizer_state"])
                        self.adaptive_load_nets(self.schedulers[_k], checkpoint[name][net_state_no[name]]["scheduler_state"])
                    else:
                        self.adaptive_load_nets(self.optimizers[_k], checkpoint[name]["optimizer_state"])
                        self.adaptive_load_nets(self.schedulers[_k], checkpoint[name]["scheduler_state"])
            self.iter = checkpoint["iter"] if 'iter' in checkpoint else 0
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg['training']['resume'], self.iter
                )
            )
        else:
            raise Exception("No checkpoint found at '{}'".format(cfg['training']['resume']))

    def load_PredNet(self, cfg, writer, logger, dir=None, net=None):
        # load pretrained weights on the net
        """Restore the frozen teacher segmentation net from
        cfg['training']['Pred_resume'] (or an explicit `dir`).
        Returns the checkpoint's best_iou (0 when absent)."""
        dir = dir or cfg['training']['Pred_resume']
        best_iou = 0
        if os.path.isfile(dir):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(dir)
            )
            checkpoint = torch.load(dir)
            name = net.__class__.__name__
            if checkpoint.get(name) == None:
                return
            if name.find('FCDiscriminator') != -1 and cfg['training']['gan_resume'] == False:
                return
            if isinstance(checkpoint[name], list):
                self.adaptive_load_nets(net, checkpoint[name][0]["model_state"])
            else:
                self.adaptive_load_nets(net, checkpoint[name]["model_state"])
            if 'iter' in checkpoint:
                checkpoint_iter = checkpoint["iter"]
            else:
                checkpoint_iter = 0
            if 'best_iou' in checkpoint:
                best_iou = checkpoint['best_iou']
            else:
                best_iou = 0
            logger.info(
                "Loaded checkpoint '{}' (iter {}) (best iou {}) for PredNet".format(
                    dir, checkpoint_iter, best_iou
                )
            )
        else:
            raise Exception("No checkpoint found at '{}'".format(dir))
        if hasattr(net, 'best_iou'):
            pass
        return best_iou

    def load_PredDnet(self, cfg, writer, logger, dir=None, net=None):
        # load pretrained weights on the net
        """Restore the frozen teacher discriminator; in list-style checkpoints
        the LAST discriminator entry (attention branch) is used."""
        dir = dir or cfg['training']['Pred_resume']
        best_iou = 0
        if os.path.isfile(dir):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(dir)
            )
            checkpoint = torch.load(dir)
            name = net.__class__.__name__
            if checkpoint.get(name) == None:
                return
            if name.find('FCDiscriminator') != -1 and cfg['training']['gan_resume'] == False:
                return
            if isinstance(checkpoint[name], list):
                self.adaptive_load_nets(net, checkpoint[name][-1]["model_state"])  # attention-branch discriminator
            else:
                print("[WARNING] load discriminator maybe error!")
                self.adaptive_load_nets(net, checkpoint[name]["model_state"])
            print("[INFO] {}: {}".format(name, net))
            iter = checkpoint["iter"]
            logger.info(
                "Loaded checkpoint '{}' (iter {}) for PredNet".format(
                    dir, checkpoint["iter"]
                )
            )
        else:
            raise Exception("No checkpoint found at '{}'".format(dir))
        return best_iou

    def set_optimizer(self, optimizer):
        #set optimizer to all nets
        # intentionally a no-op placeholder
        pass

    def reset_objective_SingleVector(self,):
        """Zero all per-branch class prototype buffers."""
        self.objective_vectors_group = torch.zeros(self.modal_num + 1, 19, 256).cuda()
        self.objective_vectors_num_group = torch.zeros(self.modal_num + 1, 19).cuda()
        self.objective_vectors_dis_group = torch.zeros(self.modal_num + 1, 19, 19).cuda()

    def update_objective_SingleVector(self, vectors, vectors_num, name='moving_average'):
        """Update the class prototype buffers, either by EMA (alpha=0.0001)
        or by a running count-weighted mean. No-op when `vectors` sums to 0."""
        if torch.sum(vectors) == 0:
            return
        """
        if self.objective_vectors_num_group[modal_idx][id] < 100:
            name = 'mean'
        """
        if name == 'moving_average':
            self.objective_vectors_group = self.objective_vectors_group * 0.9999 + 0.0001 * vectors
            self.objective_vectors_num_group += vectors_num
            # NOTE(review): builtin min() on a multi-element tensor — the
            # 'mean' branch below uses torch.min with a tensor cap instead;
            # confirm this branch is exercised/intended.
            self.objective_vectors_num_group = min(self.objective_vectors_num_group, 3000)
        elif name == 'mean':
            self.objective_vectors_group = self.objective_vectors_group * self.objective_vectors_num_group.view(-1, 19, 1).expand(self.modal_num+1, 19, 256) + vectors
            self.objective_vectors_num_group = self.objective_vectors_num_group + vectors_num
            _objective_vectors_num_group = self.objective_vectors_num_group.clone()
            # avoid division by zero for classes never observed
            _ids = torch.where(_objective_vectors_num_group == 0)
            _objective_vectors_num_group[_ids] = 1.0
            self.objective_vectors_group = self.objective_vectors_group / _objective_vectors_num_group.view(-1, 19, 1).expand(self.modal_num+1, 19, 256)
            self.objective_vectors_num_group = torch.min(self.objective_vectors_num_group, torch.Tensor([3000]).cuda())
        else:
            raise NotImplementedError('no such updating way of objective vectors {}'.format(name))
class multisource_metatrainer(object):
    """Two-source meta-training (MLDG-style) trainer.

    Builds a DeepLab generator plus an FCDiscriminator and, in 'meta' mode,
    performs meta-learning across the two source domains: one domain acts as
    meta-train, the other as meta-test.
    """

    def __init__(self, args, nnclass, meta_update_lr, meta_update_step, beta, pretrain_mode='meta'):
        # args: parsed CLI namespace; nnclass: number of segmentation classes
        # meta_update_lr / meta_update_step: inner-loop LR and step count
        # beta: weight of the meta-test term; pretrain_mode: 'meta' or plain training
        self.device = 1
        self.generator_model = None
        self.generator_optim = None
        self.generator_criterion = None
        self.pretrain_mode = pretrain_mode
        self.batch_size = args.batch_size
        self.nnclass = nnclass
        self.init_generator(args)
        self.init_discriminator(args)
        self.init_optimizer(args)
        self.meta_update_lr = meta_update_lr
        self.meta_update_step = meta_update_step
        self.beta = beta

    def init_generator(self, args):
        """Build the DeepLab generator (wrapped in DataParallel) and optionally
        warm-start it from `args.resume`, skipping the classifier head."""
        self.generator_model = DeepLab(num_classes=self.nnclass,
                                       backbone='resnet',
                                       output_stride=16,
                                       sync_bn=None,
                                       freeze_bn=False).cuda()
        self.generator_model = torch.nn.DataParallel(
            self.generator_model).cuda()
        patch_replication_callback(self.generator_model)
        if args.resume:
            print('#--------- load pretrained model --------------#')
            model_dict = self.generator_model.module.state_dict()
            checkpoint = torch.load(args.resume)
            # keep only weights that exist in the current model; drop 'last_conv' (head)
            pretrained_dict = {
                k: v
                for k, v in checkpoint['state_dict'].items()
                if 'last_conv' not in k and k in model_dict.keys()
            }
            #pretrained_dict = {k.replace('module.',''):v for k,v in checkpoint['state_dict'].items() if 'last_conv' not in k}
            model_dict.update(pretrained_dict)
            self.generator_model.module.load_state_dict(model_dict)
        for param in self.generator_model.parameters():
            param.requires_grad = True

    def init_discriminator(self, args):
        # init D: binary source-vs-target classifier over generator features
        self.discriminator_model = FCDiscriminator(num_classes=2).cuda()
        self.interp = nn.Upsample(size=400, mode='bilinear')
        self.disc_criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        return

    def init_optimizer(self, args):
        """Create the segmentation criterion and one joint Adadelta optimizer
        over generator (1x/10x LR groups) and discriminator (5x LR) params."""
        self.generator_criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(
                mode='bce')  #torch.nn.BCELoss(reduce ='mean')
        self.generator_params = [{
            'params': self.generator_model.module.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': self.generator_model.module.get_10x_lr_params(),
            'lr': args.lr * 10
        }]
        self.discriminator_params = [{
            'params': self.discriminator_model.parameters(),
            'lr': args.lr * 5
        }]
        self.model_optim = torch.optim.Adadelta(self.generator_params +
                                                self.discriminator_params)
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      lr_step=30,
                                      iters_per_epoch=100)

    # for madan the src_image has shape B x source_index x channel x H x W
    def update_weights(self, srca, srca_labels, src_b, srcb_labels, target_img, target_label):
        """One optimization step over the two concatenated source batches.

        Returns (seg_loss, tgt_loss) as detached tensors; tgt_loss is only a
        monitoring value and is never back-propagated.
        """
        #self.pretrain_mode = 'meta'
        src_labels = torch.cat([srca_labels.squeeze(), srcb_labels.squeeze()],
                               0).type(torch.LongTensor).cuda()
        src_image = torch.cat([srca.squeeze(), src_b.squeeze()])
        if self.pretrain_mode == 'meta':
            seg_loss = self.meta_mldg(src_image, src_labels, self.batch_size)
        else:
            print('a default training is enabled')
            src_out, source_feature = self.generator_model(src_image)
            seg_loss = self.generator_criterion(src_out, src_labels)
        self.model_optim.zero_grad()
        seg_loss.backward()
        self.model_optim.step()
        # target forward pass is for reporting only (loss detached below)
        target_logit, _ = self.generator_model(target_img.cuda())
        tgt_loss = self.generator_criterion(target_logit, target_label)
        tgt_loss = tgt_loss.detach()
        seg_loss = seg_loss.detach()
        return seg_loss, tgt_loss

    def meta_mldg(self, src_image, src_labels, batch_size):
        """MLDG inner/outer loop: pick one source domain S as meta-train, the
        other (V) as meta-test, and accumulate a combined loss."""
        # NOTE(review): the `batch_size` parameter is immediately overwritten with a
        # hard-coded 4 — confirm whether the caller's value should be honored.
        batch_size = 4
        num_src = 2
        S = np.random.choice(num_src)
        V = abs(S - 1)  # the other domain
        source_out, _ = self.generator_model(src_image[S * batch_size:(S + 1) *
                                                       batch_size].squeeze())
        losses = self.generator_criterion(
            source_out, src_labels[S * batch_size:(S + 1) * batch_size])
        for k in range(1, self.meta_update_step):
            source_out, _ = self.generator_model(
                src_image[S * batch_size:(S + 1) * batch_size].squeeze())
            loss = self.generator_criterion(
                source_out, src_labels[S * batch_size:(S + 1) * batch_size])
            # inner-loop gradient and one SGD step in "fast weights" space
            grad = torch.autograd.grad(loss, self.generator_model.parameters())
            fast_weights = list(
                map(lambda p: p[1] - self.meta_update_lr * p[0],
                    zip(grad, self.generator_model.parameters())))
            # compute the test loss on the fast weights
            # assumes the model's forward supports a functional call with
            # explicit weights and bn_training — TODO confirm
            Grad_test = self.generator_model(src_image[V * batch_size:(V + 1) * batch_size], fast_weights, bn_training=True)
            # compute the gradient on generator_model
            # NOTE(review): `Grad_test` looks like a forward output, not a scalar
            # loss, yet it is added to `losses` — verify the functional forward
            # actually returns a loss value here.
            losses += self.beta * Grad_test
        return losses
class madan_trainer(object):
    """Multi-source adversarial domain adaptation (MADAN-style) trainer.

    Builds a DeepLab generator, a binary FCDiscriminator and one joint
    Adadelta optimizer, and exposes two training steps:

    * :meth:`update_weights`      — adversarial step with cross-entropy domain
      losses on the discriminator output.
    * :meth:`update_wasserstein`  — Wasserstein critic losses with a gradient
      penalty and weight clipping.
    """

    def __init__(self, args, nnclass, ndomains):
        # args: parsed CLI namespace (lr, batch_size, resume, cuda, loss_type, ...)
        # nnclass: number of segmentation classes
        # ndomains: number of source domains (the data loaders supply 2)
        self.device = 1
        self.generator_model = None
        self.generator_optim = None
        self.generator_criterion = None
        self.batch_size = args.batch_size
        self.nnclass = nnclass
        self.num_domains = ndomains
        self.init_wasserstein = Wasserstein()
        self.init_generator(args)
        self.init_discriminator(args)
        self.init_optimizer(args)

    def init_generator(self, args):
        """Build the DeepLab generator (DataParallel) and optionally
        warm-start it from ``args.resume``, skipping the classifier head."""
        self.generator_model = DeepLab(num_classes=self.nnclass,
                                       backbone='resnet',
                                       output_stride=16,
                                       sync_bn=None,
                                       freeze_bn=False).cuda()
        self.generator_model = torch.nn.DataParallel(
            self.generator_model).cuda()
        patch_replication_callback(self.generator_model)
        if args.resume:
            print('#--------- load pretrained model --------------#')
            model_dict = self.generator_model.module.state_dict()
            checkpoint = torch.load(args.resume)
            # keep only weights present in the current model; drop 'last_conv' (head)
            pretrained_dict = {
                k: v
                for k, v in checkpoint['state_dict'].items()
                if 'last_conv' not in k and k in model_dict.keys()
            }
            model_dict.update(pretrained_dict)
            self.generator_model.module.load_state_dict(model_dict)
        for param in self.generator_model.parameters():
            param.requires_grad = True

    def init_discriminator(self, args):
        # init D: binary source-vs-target classifier over generator features
        self.discriminator_model = FCDiscriminator(num_classes=2).cuda()
        self.interp = nn.Upsample(size=400, mode='bilinear')
        self.disc_criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        return

    def init_optimizer(self, args):
        """Create the segmentation criterion and one joint Adadelta optimizer
        over generator (1x/10x LR groups) and discriminator (5x LR) params."""
        self.generator_criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(
                mode='bce')  #torch.nn.BCELoss(reduce ='mean')
        self.generator_params = [{
            'params': self.generator_model.module.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': self.generator_model.module.get_10x_lr_params(),
            'lr': args.lr * 10
        }]
        self.discriminator_params = [{
            'params': self.discriminator_model.parameters(),
            'lr': args.lr * 5
        }]
        self.model_optim = torch.optim.Adadelta(self.generator_params +
                                                self.discriminator_params)
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      lr_step=30,
                                      iters_per_epoch=100)

    # for madan the src_image has shape B x source_index x channel x H x W
    def update_weights(self, src_image, src_labels, targ_image, targ_labels,
                       options):
        """One adversarial step with cross-entropy domain losses.

        src_image/src_labels: B x domain x ... tensors for the two source
        domains; targ_image/targ_labels: one target batch.
        options: dict with 'mode' ('maxmin' or 'dynamic'), 'mu', 'gamma'.
        Returns (running_loss, target_loss) as numpy scalars; target_loss is
        monitoring-only.
        """
        running_loss = 0.0
        src_labels = torch.cat(
            [src_labels[:, 0].squeeze(), src_labels[:, 1].squeeze()],
            0).type(torch.LongTensor).cuda()
        self.model_optim.zero_grad()
        # src image shape batch_size x domain x 3 channels x height x width
        src_out, source_feature = self.generator_model(
            torch.cat([src_image[:, 0].squeeze(), src_image[:, 1].squeeze()]))
        targ_out, target_feature = self.generator_model(targ_image)
        # Discriminator input: num_domains*B source rows followed by B target rows
        discriminator_x = torch.cat([source_feature, target_feature]).squeeze()
        disc_clf = self.discriminator_model(discriminator_x)
        # Per-domain segmentation losses.
        # FIX: original sliced `[j*bs : j+bs]`, which is not a per-domain chunk
        # (for j=1 it selects rows bs..bs+1); use `[j*bs : (j+1)*bs]` exactly as
        # update_wasserstein below already does.
        losses = torch.stack([
            self.generator_criterion(
                src_out[j * self.batch_size:(j + 1) * self.batch_size],
                src_labels[j * self.batch_size:(j + 1) * self.batch_size])
            for j in range(self.num_domains)
        ])
        slabels = torch.ones(self.batch_size,
                             disc_clf.shape[2],
                             disc_clf.shape[3],
                             requires_grad=False).type(
                                 torch.LongTensor).cuda()
        # FIX: the target chunk of disc_clf holds exactly one batch
        # (batch_size rows — see discriminator_x above), so the target labels
        # must be batch_size rows as well; the original allocated 2*batch_size.
        tlabels = torch.zeros(self.batch_size,
                              disc_clf.shape[2],
                              disc_clf.shape[3],
                              requires_grad=False).type(
                                  torch.LongTensor).cuda()
        # Domain-confusion losses: per source domain (label 1) ...
        domain_losses = torch.stack([
            self.generator_criterion(
                disc_clf[j * self.batch_size:(j + 1) *
                         self.batch_size].squeeze(), slabels)
            for j in range(self.num_domains)
        ])
        # ... plus the target chunk (label 0). FIX: original sliced
        # `[2*bs : 4*bs]`, past the end of the num_domains*bs + bs rows; take
        # the single target batch as update_wasserstein does.
        domain_losses = torch.cat([
            domain_losses,
            self.generator_criterion(
                disc_clf[self.num_domains * self.batch_size:
                         self.num_domains * self.batch_size +
                         self.batch_size].squeeze(), tlabels).view(-1)
        ])
        # Different final loss function depending on different training modes.
        if options['mode'] == "maxmin":
            loss = torch.max(losses) + options['mu'] * torch.min(domain_losses)
        elif options['mode'] == "dynamic":
            # NOTE(review): `losses` has num_domains entries while
            # `domain_losses` has num_domains+1, so this elementwise sum relies
            # on broadcasting that may not hold — confirm before using this mode.
            loss = torch.log(
                torch.sum(
                    torch.exp(options['gamma'] *
                              (losses + options['mu'] * domain_losses)))
            ) / options['gamma']
        else:
            raise ValueError(
                "No support for the training mode on madnNet: {}.".format(
                    options['mode']))
        loss.backward()
        self.model_optim.step()
        running_loss += loss.detach().cpu().numpy()
        # compute target loss (monitoring only — detached)
        target_loss = self.generator_criterion(
            targ_out, targ_labels).detach().cpu().numpy()
        return running_loss, target_loss

    def update_wasserstein(self, src_image, src_labels, targ_image,
                           targ_labels, options):
        """One adversarial step using Wasserstein critic losses plus a
        gradient penalty, followed by weight clipping on the critic.

        Same argument/return contract as :meth:`update_weights`.
        """
        running_loss = 0.0
        src_labels = torch.cat(
            [src_labels[:, 0].squeeze(), src_labels[:, 1].squeeze()],
            0).type(torch.LongTensor).cuda()
        self.model_optim.zero_grad()
        # src image shape batch_size x domain x 3 channels x height x width
        src_out, source_feature = self.generator_model(
            torch.cat([src_image[:, 0].squeeze(), src_image[:, 1].squeeze()]))
        targ_out, target_feature = self.generator_model(targ_image)
        # Discriminator (critic)
        discriminator_x = torch.cat([source_feature, target_feature]).squeeze()
        disc_clf = self.discriminator_model(discriminator_x)
        # Per-domain segmentation losses
        losses = torch.stack([
            self.generator_criterion(
                src_out[j * self.batch_size:j * self.batch_size +
                        self.batch_size],
                src_labels[j * self.batch_size:j * self.batch_size +
                           self.batch_size]) for j in range(self.num_domains)
        ])
        # Wasserstein distance between each source chunk and the target chunk
        wass_loss = [
            self.init_wasserstein.update_wasserstein_dual_source(
                disc_clf[j * self.batch_size:j * self.batch_size +
                         self.batch_size].squeeze(),
                disc_clf[self.num_domains * self.batch_size:self.num_domains *
                         self.batch_size + self.batch_size].squeeze())
            for j in range(self.num_domains)
        ]
        domain_losses = torch.stack(wass_loss)
        # compute gradient penalty (on detached features: penalizes the critic only)
        penalty_cup, penalty_disc = self.init_wasserstein.gradient_regularization_dual_source(
            self.discriminator_model, source_feature.detach(),
            target_feature.detach(), options['batch_size'],
            options['num_domains'])
        # Different final loss function depending on different training modes.
        if options['mode'] == "maxmin":
            loss = torch.max(
                losses) + options['mu'] * torch.min(domain_losses) + options[
                    'gamma'] * penalty_cup + options['gamma'] * penalty_disc
        elif options['mode'] == "dynamic":
            # TODO Wasserstein not implemented yet for this
            loss = torch.log(
                torch.sum(
                    torch.exp(options['gamma'] *
                              (losses + options['mu'] * domain_losses)))
            ) / options['gamma']
        else:
            raise ValueError(
                "No support for the training mode on madnNet: {}.".format(
                    options['mode']))
        loss.backward()
        self.model_optim.step()
        # WGAN weight clipping keeps the critic (approximately) 1-Lipschitz
        for p in self.discriminator_model.parameters():
            p.data.clamp_(-0.01, 0.01)
        running_loss += loss.detach().cpu().numpy()
        # compute target loss (monitoring only — detached)
        target_loss = self.generator_criterion(
            targ_out, targ_labels).detach().cpu().numpy()
        return running_loss, target_loss
def main():
    """Create the model and start the training."""
    # Mean-teacher + multi-scale adversarial training for REFUGE
    # source->target domain adaptation with a UNet student/teacher pair and
    # one discriminator per decoder output scale.
    args = get_arguments()
    cudnn.enabled = True
    n_discriminators = 5  # one discriminator per supervised output head

    # create teacher & student
    student_net = UNet(3, n_classes=args.num_classes)
    teacher_net = UNet(3, n_classes=args.num_classes)
    student_params = list(student_net.parameters())
    # teacher doesn't need gradient as it's just a EMA of the student
    teacher_params = list(teacher_net.parameters())
    for param in teacher_params:
        param.requires_grad = False
    student_net.train()
    student_net.cuda(args.gpu)
    teacher_net.train()
    teacher_net.cuda(args.gpu)
    cudnn.benchmark = True

    # per-head weights for the consistency (unsup) and adversarial terms
    unsup_weights = [
        args.unsup_weight5, args.unsup_weight6, args.unsup_weight7,
        args.unsup_weight8, args.unsup_weight9
    ]
    lambda_adv_tgts = [
        args.lambda_adv_tgt5, args.lambda_adv_tgt6, args.lambda_adv_tgt7,
        args.lambda_adv_tgt8, args.lambda_adv_tgt9
    ]

    # create a list of discriminators
    discriminators = []
    for dis_idx in range(n_discriminators):
        discriminators.append(FCDiscriminator(num_classes=args.num_classes))
        discriminators[dis_idx].train()
        discriminators[dis_idx].cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # data loaders (max_iters makes the datasets cycle long enough for training)
    max_iters = args.num_steps * args.iter_size * args.batch_size
    src_set = REFUGE(True,
                     domain='REFUGE_SRC',
                     is_transform=True,
                     augmentations=aug_student,
                     aug_for_target=aug_teacher,
                     max_iters=max_iters)
    src_loader = data.DataLoader(src_set,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)
    src_loader_iter = enumerate(src_loader)
    tgt_set = REFUGE(True,
                     domain='REFUGE_DST',
                     is_transform=True,
                     augmentations=aug_student,
                     aug_for_target=aug_teacher,
                     max_iters=max_iters)
    tgt_loader = data.DataLoader(tgt_set,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)
    tgt_loader_iter = enumerate(tgt_loader)

    # student: SGD; teacher: EMA of the student's weights
    student_optimizer = optim.SGD(student_params,
                                  lr=args.learning_rate,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
    teacher_optimizer = optim_weight_ema.WeightEMA(teacher_params,
                                                   student_params,
                                                   alpha=args.teacher_alpha)
    d_optimizers = []
    for idx in range(n_discriminators):
        optimizer = optim.Adam(discriminators[idx].parameters(),
                               lr=args.learning_rate_D,
                               betas=(0.9, 0.99))
        d_optimizers.append(optimizer)
    calc_bce_loss = torch.nn.BCEWithLogitsLoss()

    # labels for adversarial training
    source_label, tgt_label = 0, 1
    for i_iter in range(args.num_steps):
        total_seg_loss = 0
        seg_loss_vals = [0] * n_discriminators
        adv_tgt_loss_vals = [0] * n_discriminators
        d_loss_vals = [0] * n_discriminators
        unsup_loss_vals = [0] * n_discriminators
        for d_optimizer in d_optimizers:
            d_optimizer.zero_grad()
            adjust_learning_rate_D(d_optimizer, i_iter, args)
        student_optimizer.zero_grad()
        adjust_learning_rate(student_optimizer, i_iter, args)

        for sub_i in range(args.iter_size):
            # ******** Optimize source network with segmentation loss ********
            # As we don't change the discriminators, their parameters are fixed
            for discriminator in discriminators:
                for param in discriminator.parameters():
                    param.requires_grad = False

            _, src_batch = src_loader_iter.__next__()
            _, _, src_images, src_labels, _ = src_batch
            src_images = Variable(src_images).cuda(args.gpu)

            # calculate the segmentation losses
            sup_preds = list(student_net(src_images))
            seg_losses, total_seg_loss = [], 0
            for idx, sup_pred in enumerate(sup_preds):
                sup_interp_pred = (sup_pred)
                # you also can use dice loss like: dice_loss(src_labels, sup_interp_pred)
                seg_loss = Weighted_Jaccard_loss(src_labels, sup_interp_pred,
                                                 args.class_weights, args.gpu)
                seg_losses.append(seg_loss)
                total_seg_loss += seg_loss * unsup_weights[idx]
                seg_loss_vals[idx] += seg_loss.item() / args.iter_size

            _, tgt_batch = tgt_loader_iter.__next__()
            # tgt_images0: teacher-augmented view, tgt_images1: student view
            tgt_images0, tgt_lbl0, tgt_images1, tgt_lbl1, _ = tgt_batch
            tgt_images0 = Variable(tgt_images0).cuda(args.gpu)
            tgt_images1 = Variable(tgt_images1).cuda(args.gpu)

            # calculate ensemble losses (student vs teacher consistency)
            stu_unsup_preds = list(student_net(tgt_images1))
            tea_unsup_preds = teacher_net(tgt_images0)
            total_mse_loss = 0
            for idx in range(n_discriminators):
                # NOTE(review): softmax over dim=-1 is the width axis for NCHW
                # maps; the channel axis would be dim=1 — confirm intended.
                stu_unsup_probs = F.softmax(stu_unsup_preds[idx], dim=-1)
                tea_unsup_probs = F.softmax(tea_unsup_preds[idx], dim=-1)
                unsup_loss = calc_mse_loss(stu_unsup_probs, tea_unsup_probs,
                                           args.batch_size)
                unsup_loss_vals[idx] += unsup_loss.item() / args.iter_size
                total_mse_loss += unsup_loss * unsup_weights[idx]
            total_mse_loss = total_mse_loss / args.iter_size

            # As the requires_grad is set to False in the discriminator, the
            # gradients are only accumulated in the generator, the target
            # student network is optimized to make the outputs of target domain
            # images close to the outputs of source domain images
            stu_unsup_preds = list(student_net(tgt_images0))
            d_outs, total_adv_loss = [], 0
            for idx in range(n_discriminators):
                stu_unsup_interp_pred = (stu_unsup_preds[idx])
                d_outs.append(discriminators[idx](stu_unsup_interp_pred))
                label_size = d_outs[idx].data.size()
                # fool the discriminator: target preds labeled as "source"
                labels = torch.FloatTensor(label_size).fill_(source_label)
                labels = Variable(labels).cuda(args.gpu)
                adv_tgt_loss = calc_bce_loss(d_outs[idx], labels)
                total_adv_loss += lambda_adv_tgts[idx] * adv_tgt_loss
                adv_tgt_loss_vals[idx] += adv_tgt_loss.item() / args.iter_size
            total_adv_loss = total_adv_loss / args.iter_size

            # requires_grad is set to True in the discriminator, we only
            # accumulate gradients in the discriminators, the discriminators are
            # optimized to make true predictions
            d_losses = []
            for idx in range(n_discriminators):
                discriminator = discriminators[idx]
                for param in discriminator.parameters():
                    param.requires_grad = True
                # detach so the student gets no gradient from the D update
                sup_preds[idx] = sup_preds[idx].detach()
                d_outs[idx] = discriminators[idx](sup_preds[idx])
                label_size = d_outs[idx].data.size()
                labels = torch.FloatTensor(label_size).fill_(source_label)
                labels = Variable(labels).cuda(args.gpu)
                d_losses.append(calc_bce_loss(d_outs[idx], labels))
                d_losses[idx] = d_losses[idx] / args.iter_size / 2
                d_losses[idx].backward()
                d_loss_vals[idx] += d_losses[idx].item()
            # second half of the D update: target predictions labeled "target"
            for idx in range(n_discriminators):
                stu_unsup_preds[idx] = stu_unsup_preds[idx].detach()
                d_outs[idx] = discriminators[idx](stu_unsup_preds[idx])
                label_size = d_outs[idx].data.size()
                labels = torch.FloatTensor(label_size).fill_(tgt_label)
                labels = Variable(labels).cuda(args.gpu)
                d_losses[idx] = calc_bce_loss(d_outs[idx], labels)
                d_losses[idx] = d_losses[idx] / args.iter_size / 2
                d_losses[idx].backward()
                d_loss_vals[idx] += d_losses[idx].item()

        for d_optimizer in d_optimizers:
            d_optimizer.step()
        total_loss = total_seg_loss + total_adv_loss + total_mse_loss
        total_loss.backward()
        student_optimizer.step()
        # teacher follows the student via EMA (no gradients)
        teacher_optimizer.step()

        log_str = 'iter = {0:7d}/{1:7d}'.format(i_iter, args.num_steps)
        log_str += ', total_seg_loss = {0:.3f} '.format(total_seg_loss)
        templ = 'seg_losses = [' + ', '.join(['%.2f'] * len(seg_loss_vals))
        log_str += templ % tuple(seg_loss_vals) + '] '
        templ = 'ens_losses = [' + ', '.join(['%.5f'] * len(unsup_loss_vals))
        log_str += templ % tuple(unsup_loss_vals) + '] '
        templ = 'adv_losses = [' + ', '.join(['%.2f'] * len(adv_tgt_loss_vals))
        log_str += templ % tuple(adv_tgt_loss_vals) + '] '
        templ = 'd_losses = [' + ', '.join(['%.2f'] * len(d_loss_vals))
        log_str += templ % tuple(d_loss_vals) + '] '
        print(log_str)

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            filename = 'UNet' + str(
                args.num_steps_stop) + '_v18_weightedclass.pth'
            torch.save(teacher_net.cpu().state_dict(),
                       os.path.join(args.snapshot_dir, filename))
            break
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            filename = 'UNet' + str(i_iter) + '_v18_weightedclass.pth'
            # .cpu() moves the model off-GPU for saving; move it back after
            torch.save(teacher_net.cpu().state_dict(),
                       os.path.join(args.snapshot_dir, filename))
            teacher_net.cuda(args.gpu)
class adda_trainer(object):
    """ADDA-style adversarial trainer: a DeepLab target model is trained
    against a binary FCDiscriminator that separates source from target
    features; update_weights alternates generator and discriminator phases."""

    def __init__(self, args, nnclass):
        # args: parsed CLI namespace; nnclass: number of segmentation classes
        self.target_model = None
        self.target_optim = None
        self.target_criterion = None
        self.batch_size = args.batch_size
        self.nnclass = nnclass
        self.init_target(args)
        self.init_discriminator(args)
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      lr_step=40,
                                      iters_per_epoch=100)
        self.disc_params = [{
            'params': self.disc_model.parameters(),
            'lr': args.lr * 5
        }]
        # separate optimizers: generator phase vs discriminator phase
        self.dda_optim = torch.optim.Adam(self.train_params)
        self.discriminator_optim = torch.optim.Adam(self.disc_params)
        #self.dda_optim = torch.optim.SGD(self.train_params, momentum=args.momentum,
        #                                 weight_decay=args.weight_decay, nesterov=args.nesterov)
        #self.discriminator_optim = torch.optim.SGD(self.disc_params, momentum=args.momentum,
        #                                           weight_decay=args.weight_decay, nesterov=args.nesterov)
        # adversarial-example augmenter (FGSM-style, L-inf bound)
        self.adv_aug = FastGradientSignUntargeted(self.target_model,
                                                  0.0157,
                                                  0.00784,
                                                  min_val=0,
                                                  max_val=1,
                                                  max_iters=2,
                                                  _type='linf')

    def init_target(self, args):
        """Build the DeepLab target model (DataParallel), its criterion, and
        load the checkpoint from `args.resume` (required here — no guard)."""
        self.target_model = DeepLab(num_classes=self.nnclass,
                                    backbone='resnet',
                                    output_stride=16,
                                    sync_bn=None,
                                    freeze_bn=False)
        self.train_params = [{
            'params': self.target_model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': self.target_model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]
        self.target_model = torch.nn.DataParallel(self.target_model)
        self.target_criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(
                mode='bce')  #torch.nn.BCELoss(reduce ='mean')
        patch_replication_callback(self.target_model)
        model_dict = self.target_model.module.state_dict()
        checkpoint = torch.load(args.resume)
        # strip a possible 'module.' prefix from DataParallel checkpoints
        pretrained_dict = {
            k.replace('module.', ''): v
            for k, v in checkpoint['state_dict'].items()
        }
        #pretrained_dict = {k:v for k,v in checkpoint['state_dict'].items() if 'last_conv' not in k }
        model_dict.update(pretrained_dict)
        self.target_model.module.load_state_dict(model_dict)
        self.target_model = self.target_model.cuda()
        return

    def init_discriminator(self, args):
        # init D: binary source-vs-target classifier over model features
        self.disc_model = FCDiscriminator(num_classes=2).cuda()
        self.interp = nn.Upsample(size=400, mode='bilinear')
        self.disc_criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        return

    def update_weights(self, input_, src_labels, target, tgt_labels, lamda_g,
                       trainmodel):
        """One alternating step.

        trainmodel == 'train_gen': update the target model with seg loss plus
        lamda_g * adversarial (domain-confusion) loss, discriminator frozen.
        Otherwise: update the discriminator only, target model frozen.
        Returns (seg_loss, tgt_loss) as numpy scalars; tgt_loss is monitoring-only.
        """
        self.dda_optim.zero_grad()
        self.discriminator_optim.zero_grad()
        # phase switch: freeze one network, train the other
        if trainmodel == 'train_gen':
            for param in self.target_model.parameters():
                param.requires_grad = True
            for param in self.disc_model.parameters():
                param.requires_grad = False
            self.disc_model.eval()
            self.target_model.train()
        else:
            for param in self.target_model.parameters():
                param.requires_grad = False
            for param in self.disc_model.parameters():
                param.requires_grad = True
            self.disc_model.train()
            self.target_model.eval()
        #tot_input = torch.cat([input_, target])
        #import pdb
        #pdb.set_trace()
        src_out, source_feature = self.target_model(input_)
        seg_loss = self.target_criterion(src_out, src_labels)
        #print(target.shape)
        targ_out, target_feature = self.target_model(target)
        # discriminator: source and target features stacked along the batch dim
        discriminator_x = torch.cat([source_feature, target_feature]).squeeze()
        # "adv" labels invert domains (to fool D); "real" labels are the truth
        discriminator_adv_logit = torch.cat([
            torch.zeros(source_feature.shape),
            torch.ones(target_feature.shape)
        ])
        discriminator_real_logit = torch.cat([
            torch.ones(source_feature.shape),
            torch.zeros(target_feature.shape)
        ])
        disc_out = self.disc_model(discriminator_x)
        #print(source_feature.shape, input_.shape,discriminator_adv_logit.shape, disc_out.shape)
        # NOTE(review): the same disc_out is compared against channels 0 and 1
        # of the label tensor and the two terms are summed — confirm this is
        # the intended two-sided loss rather than a copy/paste artifact.
        adv_loss = self.target_criterion(
            disc_out, discriminator_adv_logit[:, 0, :, :].cuda())
        adv_loss += self.target_criterion(
            disc_out, discriminator_adv_logit[:, 1, :, :].cuda())
        disc_loss = self.disc_criterion(
            disc_out, discriminator_real_logit[:, 0, :, :].cuda())
        disc_loss += self.disc_criterion(
            disc_out, discriminator_real_logit[:, 1, :, :].cuda())
        if trainmodel == 'train_gen':
            loss_seg = seg_loss + lamda_g * adv_loss
            loss_seg.backward()
            self.dda_optim.step()
        else:
            disc_loss.backward()
            self.discriminator_optim.step()
        # target loss is monitoring-only (never back-propagated)
        tgt_loss = self.target_criterion(targ_out, tgt_labels)
        return seg_loss.data.cpu().numpy(), tgt_loss.data.cpu().numpy()
    def __init__(self, opt, logger, isTrain=True):
        """Build the ProDA-style trainer state: Deeplab student (plus optional
        EMA model, teacher, and warm-up discriminator), optimizers and
        schedulers.

        NOTE(review): this method is byte-identical to CustomModel.__init__
        defined later in this file — likely a stale duplicate; its enclosing
        class header is outside this chunk.
        """
        self.opt = opt
        self.class_numbers = opt.n_class
        self.logger = logger
        self.best_iou = -100
        self.nets = []          # raw (non-DataParallel) networks
        self.nets_DP = []       # DataParallel-wrapped counterparts
        self.default_gpu = 0
        # per-class prototype vectors (256-dim) and their observation counts
        self.objective_vectors = torch.zeros([self.class_numbers, 256])
        self.objective_vectors_num = torch.zeros([self.class_numbers])
        if opt.bn == 'sync_bn':
            BatchNorm = SynchronizedBatchNorm2d
        elif opt.bn == 'bn':
            BatchNorm = nn.BatchNorm2d
        else:
            raise NotImplementedError('batch norm choice {} is not implemented'.format(opt.bn))
        if self.opt.no_resume:
            restore_from = None
        else:
            restore_from= opt.resume_path
            self.best_iou = 0
        # student backbone; 'simclr' init loads self-supervised pretrained weights
        if self.opt.student_init == 'imagenet':
            self.BaseNet = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from)
        elif self.opt.student_init == 'simclr':
            self.BaseNet = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from,
                                   initialization=os.path.join(opt.root, 'Code/ProDA', 'pretrained/simclr/r101_1x_sk0.pth'), bn_clr=opt.bn_clr)
        else:
            self.BaseNet = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from)
        logger.info('the backbone is {}'.format(opt.model_name))
        self.nets.extend([self.BaseNet])
        self.optimizers = []
        self.schedulers = []
        optimizer_cls = torch.optim.SGD
        optimizer_params = {'lr':opt.lr, 'weight_decay':2e-4, 'momentum':0.9}
        # warm-up stage trains an adversarial discriminator alongside the student
        if self.opt.stage == 'warm_up':
            self.net_D = FCDiscriminator(inplanes=self.class_numbers)
            self.net_D_DP = self.init_device(self.net_D, gpu_id=self.default_gpu, whether_DP=True)
            self.nets.extend([self.net_D])
            self.nets_DP.append(self.net_D_DP)
            self.optimizer_D = torch.optim.Adam(self.net_D.parameters(), lr=1e-4, betas=(0.9, 0.99))
            self.optimizers.extend([self.optimizer_D])
            self.DSchedule = get_scheduler(self.optimizer_D, opt)
            self.schedulers.extend([self.DSchedule])
        # finetune/warm-up use split 1x/10x LR groups; otherwise one group
        if self.opt.finetune or self.opt.stage == 'warm_up':
            self.BaseOpti = optimizer_cls([{'params':self.BaseNet.get_1x_lr_params(), 'lr':optimizer_params['lr']},
                                           {'params':self.BaseNet.get_10x_lr_params(), 'lr':optimizer_params['lr']*10}], **optimizer_params)
        else:
            self.BaseOpti = optimizer_cls(self.BaseNet.parameters(), **optimizer_params)
        self.optimizers.extend([self.BaseOpti])
        self.BaseSchedule = get_scheduler(self.BaseOpti, opt)
        self.schedulers.extend([self.BaseSchedule])
        # optional EMA copy of the student (updated in step())
        if self.opt.ema:
            self.BaseNet_ema = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from, bn_clr=opt.ema_bn)
            self.BaseNet_ema.load_state_dict(self.BaseNet.state_dict().copy())
        # optional frozen teacher for distillation
        if self.opt.distillation > 0:
            self.teacher = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=opt.resume_path, bn_clr=opt.ema_bn)
            self.teacher.eval()
            self.teacher_DP = self.init_device(self.teacher, gpu_id=self.default_gpu, whether_DP=True)
        # adversarial label conventions: 0 = source, 1 = target
        self.adv_source_label = 0
        self.adv_target_label = 1
        if self.opt.gan == 'Vanilla':
            self.bceloss = nn.BCEWithLogitsLoss(size_average=True)
        elif self.opt.gan == 'LS':
            self.bceloss = torch.nn.MSELoss()
        self.feat_prototype_distance_DP = self.init_device(feat_prototype_distance_module(), gpu_id=self.default_gpu, whether_DP=True)
        self.BaseNet_DP = self.init_device(self.BaseNet, gpu_id=self.default_gpu, whether_DP=True)
        self.nets_DP.append(self.BaseNet_DP)
        if self.opt.ema:
            self.BaseNet_ema_DP = self.init_device(self.BaseNet_ema, gpu_id=self.default_gpu, whether_DP=True)
class CustomModel():
    """ProDA-style self-training model: Deeplab student with optional EMA
    model, distillation teacher and warm-up adversarial discriminator, plus
    prototype-based pseudo-label rectification.

    NOTE(review): this chunk ends mid-way through the signature of
    step_distillation; the truncated tail is kept verbatim.
    """

    def __init__(self, opt, logger, isTrain=True):
        # NOTE(review): byte-identical to the stray __init__ defined just above
        # this class in the file — one of the two is likely a stale duplicate.
        self.opt = opt
        self.class_numbers = opt.n_class
        self.logger = logger
        self.best_iou = -100
        self.nets = []          # raw (non-DataParallel) networks
        self.nets_DP = []       # DataParallel-wrapped counterparts
        self.default_gpu = 0
        # per-class prototype vectors (256-dim) and their observation counts
        self.objective_vectors = torch.zeros([self.class_numbers, 256])
        self.objective_vectors_num = torch.zeros([self.class_numbers])
        if opt.bn == 'sync_bn':
            BatchNorm = SynchronizedBatchNorm2d
        elif opt.bn == 'bn':
            BatchNorm = nn.BatchNorm2d
        else:
            raise NotImplementedError('batch norm choice {} is not implemented'.format(opt.bn))
        if self.opt.no_resume:
            restore_from = None
        else:
            restore_from= opt.resume_path
            self.best_iou = 0
        # student backbone; 'simclr' init loads self-supervised pretrained weights
        if self.opt.student_init == 'imagenet':
            self.BaseNet = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from)
        elif self.opt.student_init == 'simclr':
            self.BaseNet = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from,
                                   initialization=os.path.join(opt.root, 'Code/ProDA', 'pretrained/simclr/r101_1x_sk0.pth'), bn_clr=opt.bn_clr)
        else:
            self.BaseNet = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from)
        logger.info('the backbone is {}'.format(opt.model_name))
        self.nets.extend([self.BaseNet])
        self.optimizers = []
        self.schedulers = []
        optimizer_cls = torch.optim.SGD
        optimizer_params = {'lr':opt.lr, 'weight_decay':2e-4, 'momentum':0.9}
        # warm-up stage trains an adversarial discriminator alongside the student
        if self.opt.stage == 'warm_up':
            self.net_D = FCDiscriminator(inplanes=self.class_numbers)
            self.net_D_DP = self.init_device(self.net_D, gpu_id=self.default_gpu, whether_DP=True)
            self.nets.extend([self.net_D])
            self.nets_DP.append(self.net_D_DP)
            self.optimizer_D = torch.optim.Adam(self.net_D.parameters(), lr=1e-4, betas=(0.9, 0.99))
            self.optimizers.extend([self.optimizer_D])
            self.DSchedule = get_scheduler(self.optimizer_D, opt)
            self.schedulers.extend([self.DSchedule])
        # finetune/warm-up use split 1x/10x LR groups; otherwise one group
        if self.opt.finetune or self.opt.stage == 'warm_up':
            self.BaseOpti = optimizer_cls([{'params':self.BaseNet.get_1x_lr_params(), 'lr':optimizer_params['lr']},
                                           {'params':self.BaseNet.get_10x_lr_params(), 'lr':optimizer_params['lr']*10}], **optimizer_params)
        else:
            self.BaseOpti = optimizer_cls(self.BaseNet.parameters(), **optimizer_params)
        self.optimizers.extend([self.BaseOpti])
        self.BaseSchedule = get_scheduler(self.BaseOpti, opt)
        self.schedulers.extend([self.BaseSchedule])
        # optional EMA copy of the student (updated in step())
        if self.opt.ema:
            self.BaseNet_ema = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from, bn_clr=opt.ema_bn)
            self.BaseNet_ema.load_state_dict(self.BaseNet.state_dict().copy())
        # optional frozen teacher for distillation
        if self.opt.distillation > 0:
            self.teacher = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=opt.resume_path, bn_clr=opt.ema_bn)
            self.teacher.eval()
            self.teacher_DP = self.init_device(self.teacher, gpu_id=self.default_gpu, whether_DP=True)
        # adversarial label conventions: 0 = source, 1 = target
        self.adv_source_label = 0
        self.adv_target_label = 1
        if self.opt.gan == 'Vanilla':
            self.bceloss = nn.BCEWithLogitsLoss(size_average=True)
        elif self.opt.gan == 'LS':
            self.bceloss = torch.nn.MSELoss()
        self.feat_prototype_distance_DP = self.init_device(feat_prototype_distance_module(), gpu_id=self.default_gpu, whether_DP=True)
        self.BaseNet_DP = self.init_device(self.BaseNet, gpu_id=self.default_gpu, whether_DP=True)
        self.nets_DP.append(self.BaseNet_DP)
        if self.opt.ema:
            self.BaseNet_ema_DP = self.init_device(self.BaseNet_ema, gpu_id=self.default_gpu, whether_DP=True)

    def calculate_mean_vector(self, feat_cls, outputs, labels=None, thresh=None):
        """Compute per-image, per-class mean feature vectors.

        feat_cls: feature maps; outputs: class logits (same spatial size —
        TODO confirm caller resizes them to match). Returns (vectors, ids):
        masked-average feature vectors and their class indices.
        """
        outputs_softmax = F.softmax(outputs, dim=1)
        if thresh is None:
            thresh = -1  # -1 disables confidence filtering (all pixels pass)
        conf = outputs_softmax.max(dim=1, keepdim=True)[0]
        mask = conf.ge(thresh)
        outputs_argmax = outputs_softmax.argmax(dim=1, keepdim=True)
        outputs_argmax = self.process_label(outputs_argmax.float())
        if labels is None:
            outputs_pred = outputs_argmax
        else:
            # restrict predictions to pixels where they agree with the labels
            labels_expanded = self.process_label(labels)
            outputs_pred = labels_expanded * outputs_argmax
        # fraction of pixels assigned to each class (normalizer for the mean)
        scale_factor = F.adaptive_avg_pool2d(outputs_pred * mask, 1)
        vectors = []
        ids = []
        for n in range(feat_cls.size()[0]):
            for t in range(self.class_numbers):
                if scale_factor[n][t].item()==0:
                    continue
                # skip classes with fewer than 10 pixels (too noisy)
                if (outputs_pred[n][t] > 0).sum() < 10:
                    continue
                s = feat_cls[n] * outputs_pred[n][t] * mask[n]
                # scale = torch.sum(outputs_pred[n][t]) / labels.shape[2] / labels.shape[3] * 2
                # s = normalisation_pooling()(s, scale)
                s = F.adaptive_avg_pool2d(s, 1) / scale_factor[n][t]
                vectors.append(s)
                ids.append(t)
        return vectors, ids

    def step_adv(self, source_x, source_label, target_x, source_imageS, source_params):
        """Warm-up adversarial step: supervised source loss + GAN loss for the
        student, then one discriminator update. Returns the three loss items."""
        # phase 1: generator update with discriminator frozen
        for param in self.net_D.parameters():
            param.requires_grad = False
        self.BaseOpti.zero_grad()
        if self.opt.S_pseudo_src > 0:
            # train on the strongly-augmented source image with transformed labels
            source_output = self.BaseNet_DP(source_imageS)
            source_label_d4 = F.interpolate(source_label.unsqueeze(1).float(), size=source_output['out'].size()[2:])
            source_labelS = self.label_strong_T(source_label_d4.clone().float(), source_params, padding=250, scale=4).to(torch.int64)
            loss_ = cross_entropy2d(input=source_output['out'], target=source_labelS.squeeze(1))
            loss_GTA = loss_ * self.opt.S_pseudo_src
            source_outputUp = F.interpolate(source_output['out'], size=source_x.size()[2:], mode='bilinear', align_corners=True)
        else:
            source_output = self.BaseNet_DP(source_x, ssl=True)
            source_outputUp = F.interpolate(source_output['out'], size=source_x.size()[2:], mode='bilinear', align_corners=True)
            loss_GTA = cross_entropy2d(input=source_outputUp, target=source_label, size_average=True, reduction='mean')
        target_output = self.BaseNet_DP(target_x, ssl=True)
        target_outputUp = F.interpolate(target_output['out'], size=target_x.size()[2:], mode='bilinear', align_corners=True)
        # fool D: target predictions labeled as "source"
        target_D_out = self.net_D_DP(F.softmax(target_outputUp, dim=1))
        loss_adv_G = self.bceloss(target_D_out, torch.FloatTensor(target_D_out.data.size()).fill_(self.adv_source_label).to(target_D_out.device)) * self.opt.adv
        loss_G = loss_adv_G + loss_GTA
        loss_G.backward()
        self.BaseOpti.step()
        # phase 2: discriminator update on detached (gradient-free) predictions
        for param in self.net_D.parameters():
            param.requires_grad = True
        self.optimizer_D.zero_grad()
        source_D_out = self.net_D_DP(F.softmax(source_outputUp.detach(), dim=1))
        target_D_out = self.net_D_DP(F.softmax(target_outputUp.detach(), dim=1))
        loss_D = self.bceloss(source_D_out, torch.FloatTensor(source_D_out.data.size()).fill_(self.adv_source_label).to(source_D_out.device)) + \
                 self.bceloss(target_D_out, torch.FloatTensor(target_D_out.data.size()).fill_(self.adv_target_label).to(target_D_out.device))
        loss_D.backward()
        self.optimizer_D.step()
        return loss_GTA.item(), loss_adv_G.item(), loss_D.item()

    def step(self, source_x, source_label, target_x, target_imageS=None, target_params=None, target_lp=None,
             target_lpsoft=None, target_image_full=None, target_weak_params=None):
        """Self-training step: supervised source loss + pseudo-label target loss
        with optional prototype rectification, RCE, regularization and
        prototype-consistency terms; finally updates prototypes and EMA model.
        Returns (loss, loss_CTS, loss_consist) as Python floats."""
        source_out = self.BaseNet_DP(source_x, ssl=True)
        source_outputUp = F.interpolate(source_out['out'], size=source_x.size()[2:], mode='bilinear', align_corners=True)
        loss_GTA = cross_entropy2d(input=source_outputUp, target=source_label)
        # source gradient is accumulated now; optimizer steps after target loss
        loss_GTA.backward()
        # pseudo labels: soft (for rectification) or hard, at 1/4 resolution
        if self.opt.proto_rectify:
            threshold_arg = F.interpolate(target_lpsoft, scale_factor=0.25, mode='bilinear', align_corners=True)
        else:
            threshold_arg = F.interpolate(target_lp.unsqueeze(1).float(), scale_factor=0.25).long()
        if self.opt.ema:
            ema_input = target_image_full
            with torch.no_grad():
                ema_out = self.BaseNet_ema_DP(ema_input)
            ema_out['feat'] = F.interpolate(ema_out['feat'], size=(int(ema_input.shape[2]/4), int(ema_input.shape[3]/4)), mode='bilinear', align_corners=True)
            ema_out['out'] = F.interpolate(ema_out['out'], size=(int(ema_input.shape[2]/4), int(ema_input.shape[3]/4)), mode='bilinear', align_corners=True)
        target_out = self.BaseNet_DP(target_imageS) if self.opt.S_pseudo > 0 else self.BaseNet_DP(target_x)
        target_out['out'] = F.interpolate(target_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
        target_out['feat'] = F.interpolate(target_out['feat'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
        loss = torch.Tensor([0]).to(self.default_gpu)
        batch, _, w, h = threshold_arg.shape
        if self.opt.proto_rectify:
            # re-weight soft pseudo labels by prototype affinity, then threshold;
            # pixels whose max rectified prob is below train_thred become ignore (250)
            weights = self.get_prototype_weight(ema_out['feat'], target_weak_params=target_weak_params)
            rectified = weights * threshold_arg
            threshold_arg = rectified.max(1, keepdim=True)[1]
            rectified = rectified / rectified.sum(1, keepdim=True)
            argmax = rectified.max(1, keepdim=True)[0]
            threshold_arg[argmax < self.opt.train_thred] = 250
        if self.opt.S_pseudo > 0:
            threshold_argS = self.label_strong_T(threshold_arg.clone().float(), target_params, padding=250, scale=4).to(torch.int64)
            # NOTE(review): `cluster_arg` is never assigned before this line —
            # this branch raises NameError when S_pseudo > 0; confirm and fix.
            cluster_argS = self.label_strong_T(cluster_arg.clone().float(), target_params, padding=250, scale=4).to(torch.int64)
            threshold_arg = threshold_argS
        loss_CTS = cross_entropy2d(input=target_out['out'], target=threshold_arg.reshape([batch, w, h]))
        if self.opt.rce:
            # symmetric/reverse cross-entropy regularization of the CE term
            rce = self.rce(target_out['out'], threshold_arg.reshape([batch, w, h]).clone())
            loss_CTS = self.opt.rce_alpha * loss_CTS + self.opt.rce_beta * rce
        if self.opt.regular_w > 0:
            regular_loss = self.regular_loss(target_out['out'])
            loss_CTS = loss_CTS + regular_loss * self.opt.regular_w
        cluster_argS = None
        loss_consist = torch.Tensor([0]).to(self.default_gpu)
        if self.opt.proto_consistW > 0:
            # teacher distribution: softmax over negative prototype distances of
            # the EMA features, mapped into the strong-augmentation geometry
            ema2weak_feat = self.full2weak(ema_out['feat'], target_weak_params) #N*256*H*W
            ema2weak_feat_proto_distance = self.feat_prototype_distance(ema2weak_feat) #N*19*H*W
            ema2strong_feat_proto_distance = self.label_strong_T(ema2weak_feat_proto_distance, target_params, padding=250, scale=4)
            mask = (ema2strong_feat_proto_distance != 250).float()
            teacher = F.softmax(-ema2strong_feat_proto_distance * self.opt.proto_temperature, dim=1)
            targetS_out = target_out if self.opt.S_pseudo > 0 else self.BaseNet_DP(target_imageS)
            targetS_out['out'] = F.interpolate(targetS_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
            targetS_out['feat'] = F.interpolate(targetS_out['feat'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
            prototype_tmp = self.objective_vectors.expand(4, -1, -1) #gpu memory limitation
            strong_feat_proto_distance = self.feat_prototype_distance_DP(targetS_out['feat'], prototype_tmp, self.class_numbers)
            student = F.log_softmax(-strong_feat_proto_distance * self.opt.proto_temperature, dim=1)
            # masked KL(student || teacher) averaged over valid pixels
            loss_consist = F.kl_div(student, teacher, reduction='none')
            loss_consist = (loss_consist * mask).sum() / mask.sum()
            loss = loss + self.opt.proto_consistW * loss_consist
        loss = loss + loss_CTS
        loss.backward()
        self.BaseOpti.step()
        self.BaseOpti.zero_grad()
        if self.opt.moving_prototype: #update prototype
            ema_vectors, ema_ids = self.calculate_mean_vector(ema_out['feat'].detach(), ema_out['out'].detach())
            for t in range(len(ema_ids)):
                self.update_objective_SingleVector(ema_ids[t], ema_vectors[t].detach(), start_mean=False)
        if self.opt.ema: #update ema model
            # EMA decay 0.999 on parameters; buffers copied directly
            for param_q, param_k in zip(self.BaseNet.parameters(), self.BaseNet_ema.parameters()):
                param_k.data = param_k.data.clone() * 0.999 + param_q.data.clone() * (1. - 0.999)
            for buffer_q, buffer_k in zip(self.BaseNet.buffers(), self.BaseNet_ema.buffers()):
                buffer_k.data = buffer_q.data.clone()
        return loss.item(), loss_CTS.item(), loss_consist.item()

    def regular_loss(self, activation):
        """Prediction-regularization term: entropy-style (MRENT) or
        KL-to-uniform-style (MRKLD) penalty over the logits."""
        logp = F.log_softmax(activation, dim=1)
        if self.opt.regular_type == 'MRENT':
            p = F.softmax(activation, dim=1)
            loss = (p * logp).sum() / (p.shape[0]*p.shape[2]*p.shape[3])
        elif self.opt.regular_type == 'MRKLD':
            loss = - logp.sum() / (logp.shape[0]*logp.shape[1]*logp.shape[2]*logp.shape[3])
        return loss

    def rce(self, pred, labels):
        """Reverse cross-entropy over non-ignore pixels (ignore index 250).
        NOTE: mutates `labels` in place (remaps 250 -> class_numbers); callers
        pass a .clone()."""
        pred = F.softmax(pred, dim=1)
        pred = torch.clamp(pred, min=1e-7, max=1.0)
        mask = (labels != 250).float()
        labels[labels==250] = self.class_numbers
        label_one_hot = torch.nn.functional.one_hot(labels, self.class_numbers + 1).float().to(self.default_gpu)
        # drop the extra ignore channel; clamp so log() stays finite
        label_one_hot = torch.clamp(label_one_hot.permute(0,3,1,2)[:,:-1,:,:], min=1e-4, max=1.0)
        rce = -(torch.sum(pred * torch.log(label_one_hot), dim=1) * mask).sum() / (mask.sum() + 1e-6)
        return rce

    # NOTE(review): truncated here — the remainder of this signature (and body)
    # lies beyond this chunk.
    def step_distillation(self, source_x, source_label, target_x, target_imageS=None, target_params=None,
target_lp=None): source_out = self.BaseNet_DP(source_x, ssl=True) source_outputUp = F.interpolate(source_out['out'], size=source_x.size()[2:], mode='bilinear', align_corners=True) loss_GTA = cross_entropy2d(input=source_outputUp, target=source_label) loss_GTA.backward() threshold_arg = F.interpolate(target_lp.unsqueeze(1).float(), scale_factor=0.25).long() if self.opt.S_pseudo > 0: threshold_arg = self.label_strong_T(threshold_arg.clone().float(), target_params, padding=250, scale=4).to(torch.int64) target_out = self.BaseNet_DP(target_imageS) else: target_out = self.BaseNet_DP(target_x) target_out['out'] = F.interpolate(target_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True) batch, _, w, h = threshold_arg.shape loss = cross_entropy2d(input=target_out['out'], target=threshold_arg.reshape([batch, w, h]), size_average=True, reduction='mean') if self.opt.rce: rce = self.rce(target_out['out'], threshold_arg.reshape([batch, w, h]).clone()) loss = self.opt.rce_alpha * loss + self.opt.rce_beta * rce if self.opt.distillation > 0: student = F.softmax(target_out['out'], dim=1) with torch.no_grad(): teacher_out = self.teacher_DP(target_imageS) teacher_out['out'] = F.interpolate(teacher_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True) teacher = F.softmax(teacher_out['out'], dim=1) loss_kd = F.kl_div(student, teacher, reduction='none') mask = (teacher != 250).float() loss_kd = (loss_kd * mask).sum() / mask.sum() loss = loss + self.opt.distillation * loss_kd loss.backward() self.BaseOpti.step() self.BaseOpti.zero_grad() return loss_GTA.item(), loss.item() def full2weak(self, feat, target_weak_params): tmp = [] for i in range(feat.shape[0]): h, w = target_weak_params['RandomSized'][0][i], target_weak_params['RandomSized'][1][i] feat_ = F.interpolate(feat[i:i+1], size=[int(h/4), int(w/4)], mode='bilinear', align_corners=True) y1, y2, x1, x2 = target_weak_params['RandomCrop'][0][i], target_weak_params['RandomCrop'][1][i], 
target_weak_params['RandomCrop'][2][i], target_weak_params['RandomCrop'][3][i] y1, th, x1, tw = int(y1/4), int((y2-y1)/4), int(x1/4), int((x2-x1)/4) feat_ = feat_[:, :, y1:y1+th, x1:x1+tw] if target_weak_params['RandomHorizontallyFlip'][i]: inv_idx = torch.arange(feat_.size(3)-1,-1,-1).long().to(feat_.device) feat_ = feat_.index_select(3,inv_idx) tmp.append(feat_) feat = torch.cat(tmp, 0) return feat def feat_prototype_distance(self, feat): N, C, H, W = feat.shape feat_proto_distance = -torch.ones((N, self.class_numbers, H, W)).to(feat.device) for i in range(self.class_numbers): #feat_proto_distance[:, i, :, :] = torch.norm(torch.Tensor(self.objective_vectors[i]).reshape(-1,1,1).expand(-1, H, W).to(feat.device) - feat, 2, dim=1,) feat_proto_distance[:, i, :, :] = torch.norm(self.objective_vectors[i].reshape(-1,1,1).expand(-1, H, W) - feat, 2, dim=1,) return feat_proto_distance def get_prototype_weight(self, feat, label=None, target_weak_params=None): feat = self.full2weak(feat, target_weak_params) feat_proto_distance = self.feat_prototype_distance(feat) feat_nearest_proto_distance, feat_nearest_proto = feat_proto_distance.min(dim=1, keepdim=True) feat_proto_distance = feat_proto_distance - feat_nearest_proto_distance weight = F.softmax(-feat_proto_distance * self.opt.proto_temperature, dim=1) return weight def label_strong_T(self, label, params, padding, scale=1): label = label + 1 for i in range(label.shape[0]): for (Tform, param) in params.items(): if Tform == 'Hflip' and param[i].item() == 1: label[i] = label[i].clone().flip(-1) elif (Tform == 'ShearX' or Tform == 'ShearY' or Tform == 'TranslateX' or Tform == 'TranslateY' or Tform == 'Rotate') and param[i].item() != 1e4: v = int(param[i].item() // scale) if Tform == 'TranslateX' or Tform == 'TranslateY' else param[i].item() label[i:i+1] = affine_sample(label[i:i+1].clone(), v, Tform) elif Tform == 'CutoutAbs' and isinstance(param, list): x0 = int(param[0][i].item() // scale) y0 = int(param[1][i].item() // scale) 
x1 = int(param[2][i].item() // scale) y1 = int(param[3][i].item() // scale) label[i, :, y0:y1, x0:x1] = 0 label[label == 0] = padding + 1 # for strong augmentation, constant padding label = label - 1 return label def process_label(self, label): batch, channel, w, h = label.size() pred1 = torch.zeros(batch, self.class_numbers + 1, w, h).to(self.default_gpu) id = torch.where(label < self.class_numbers, label, torch.Tensor([self.class_numbers]).to(self.default_gpu)) pred1 = pred1.scatter_(1, id.long(), 1) return pred1 def freeze_bn_apply(self): for net in self.nets: net.apply(freeze_bn) for net in self.nets_DP: net.apply(freeze_bn) def scheduler_step(self): for scheduler in self.schedulers: scheduler.step() def optimizer_zerograd(self): for optimizer in self.optimizers: optimizer.zero_grad() def init_device(self, net, gpu_id=None, whether_DP=False): gpu_id = gpu_id or self.default_gpu device = torch.device("cuda:{}".format(gpu_id) if torch.cuda.is_available() else 'cpu') net = net.to(device) # if torch.cuda.is_available(): if whether_DP: #net = DataParallelWithCallback(net, device_ids=[0]) net = DataParallelWithCallback(net, device_ids=range(torch.cuda.device_count())) return net def eval(self, net=None, logger=None): """Make specific models eval mode during test time""" # if issubclass(net, nn.Module) or issubclass(net, BaseModel): if net == None: for net in self.nets: net.eval() for net in self.nets_DP: net.eval() if logger!=None: logger.info("Successfully set the model eval mode") else: net.eval() if logger!=None: logger("Successfully set {} eval mode".format(net.__class__.__name__)) return def train(self, net=None, logger=None): if net==None: for net in self.nets: net.train() for net in self.nets_DP: net.train() else: net.train() return def update_objective_SingleVector(self, id, vector, name='moving_average', start_mean=True): if vector.sum().item() == 0: return if start_mean and self.objective_vectors_num[id].item() < 100: name = 'mean' if name == 
'moving_average': self.objective_vectors[id] = self.objective_vectors[id] * (1 - self.opt.proto_momentum) + self.opt.proto_momentum * vector.squeeze() self.objective_vectors_num[id] += 1 self.objective_vectors_num[id] = min(self.objective_vectors_num[id], 3000) elif name == 'mean': self.objective_vectors[id] = self.objective_vectors[id] * self.objective_vectors_num[id] + vector.squeeze() self.objective_vectors_num[id] += 1 self.objective_vectors[id] = self.objective_vectors[id] / self.objective_vectors_num[id] self.objective_vectors_num[id] = min(self.objective_vectors_num[id], 3000) pass else: raise NotImplementedError('no such updating way of objective vectors {}'.format(name))