import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.model_zoo as model_zoo

# apex is only required for fp16 training and synced BN; guard the import so
# the module still loads without it.
try:
    import apex
    from apex import amp
except ImportError:
    apex = None
    amp = None

# Project-local modules. The exact import paths below are assumptions about
# the repository layout; adjust them to match your tree.
# from model.deeplab_multi import DeeplabMulti
# from model.discriminator import MsImageDis
# from utils.tool import weights_init


class AD_Trainer(nn.Module):
    def __init__(self, args):
        super(AD_Trainer, self).__init__()
        self.fp16 = args.fp16
        self.class_balance = args.class_balance
        self.often_balance = args.often_balance
        self.num_classes = args.num_classes
        # Running class weights for the source (plain) and target (_t) domains.
        self.class_weight = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        self.often_weight = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        self.class_weight_t = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        self.often_weight_t = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        self.multi_gpu = args.multi_gpu
        self.only_hard_label = args.only_hard_label

        if args.model == 'DeepLab':
            self.G = DeeplabMulti(num_classes=args.num_classes,
                                  use_se=args.use_se,
                                  train_bn=args.train_bn,
                                  norm_style=args.norm_style,
                                  droprate=args.droprate)
            if args.restore_from[:4] == 'http':
                saved_state_dict = model_zoo.load_url(args.restore_from)
            else:
                saved_state_dict = torch.load(args.restore_from)
            new_params = self.G.state_dict().copy()
            for i in saved_state_dict:
                # Example key: Scale.layer5.conv2d_list.3.weight
                i_parts = i.split('.')
                if args.restore_from[:4] == 'http':
                    # ImageNet weights: drop the leading scope, skip fc/layer5.
                    if i_parts[1] != 'fc' and i_parts[1] != 'layer5':
                        new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' % i_parts[1:])
                else:
                    # Local checkpoint: strip a possible DataParallel 'module.' prefix.
                    if i_parts[0] == 'module':
                        new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' % i_parts[1:])
                    else:
                        new_params['.'.join(i_parts)] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' % i_parts)
            self.G.load_state_dict(new_params)

        self.D1 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D2 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D1.apply(weights_init('gaussian'))
        self.D2.apply(weights_init('gaussian'))

        if self.multi_gpu and args.sync_bn:
            print("using apex synced BN")
            self.G = apex.parallel.convert_syncbn_model(self.G)

        self.gen_opt = optim.SGD(self.G.optim_parameters(args),
                                 lr=args.learning_rate,
                                 momentum=args.momentum,
                                 nesterov=True,
                                 weight_decay=args.weight_decay)
        self.dis1_opt = optim.Adam(self.D1.parameters(),
                                   lr=args.learning_rate_D, betas=(0.9, 0.99))
        self.dis2_opt = optim.Adam(self.D2.parameters(),
                                   lr=args.learning_rate_D, betas=(0.9, 0.99))
        self.seg_loss = nn.CrossEntropyLoss(ignore_index=255)
        # reduction='sum' replaces the deprecated size_average=False.
        self.kl_loss = nn.KLDivLoss(reduction='sum')
        self.sm = torch.nn.Softmax(dim=1)
        self.log_sm = torch.nn.LogSoftmax(dim=1)
        self.G = self.G.cuda()
        self.D1 = self.D1.cuda()
        self.D2 = self.D2.cuda()
        self.interp = nn.Upsample(size=args.crop_size, mode='bilinear', align_corners=True)
        self.interp_target = nn.Upsample(size=args.crop_size, mode='bilinear', align_corners=True)
        self.lambda_seg = args.lambda_seg
        self.max_value = args.max_value
        self.lambda_me_target = args.lambda_me_target
        self.lambda_kl_target = args.lambda_kl_target
        self.lambda_adv_target1 = args.lambda_adv_target1
        self.lambda_adv_target2 = args.lambda_adv_target2
        self.class_w = torch.FloatTensor(self.num_classes).zero_().cuda() + 1

        if args.fp16:
            # Wrap the models and optimizers with apex amp for mixed precision.
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.G, self.gen_opt = amp.initialize(self.G, self.gen_opt, opt_level="O1")
            self.D1, self.dis1_opt = amp.initialize(self.D1, self.dis1_opt, opt_level="O1")
            self.D2, self.dis2_opt = amp.initialize(self.D2, self.dis2_opt, opt_level="O1")

    def update_class_criterion(self, labels):
        # Re-weight rare (small-area) classes in the source batch.
        weight = torch.FloatTensor(self.num_classes).zero_().cuda()
        weight += 1
        count = torch.FloatTensor(self.num_classes).zero_().cuda()
        often = torch.FloatTensor(self.num_classes).zero_().cuda()
        often += 1
        n, h, w = labels.shape
        for i in range(self.num_classes):
            count[i] = torch.sum(labels == i)
            if count[i] < 64 * 64 * n:  # small object; original train size is 512x256
                weight[i] = self.max_value
        if self.often_balance:
            often[count == 0] = self.max_value
        self.often_weight = 0.9 * self.often_weight + 0.1 * often
        self.class_weight = weight * self.often_weight
        print(self.class_weight)
        return nn.CrossEntropyLoss(weight=self.class_weight, ignore_index=255)

    def update_class_criterion_t(self, labels):
        # Same balancing, but with separate running statistics for the target domain.
        weight = torch.FloatTensor(self.num_classes).zero_().cuda()
        weight += 1
        count = torch.FloatTensor(self.num_classes).zero_().cuda()
        often = torch.FloatTensor(self.num_classes).zero_().cuda()
        often += 1
        n, h, w = labels.shape
        for i in range(self.num_classes):
            count[i] = torch.sum(labels == i)
            if count[i] < 64 * 64 * n:  # small object; original train size is 512x256
                weight[i] = self.max_value
        if self.often_balance:
            often[count == 0] = self.max_value
        self.often_weight_t = 0.9 * self.often_weight_t + 0.1 * often
        self.class_weight_t = weight * self.often_weight_t
        print(self.class_weight_t)
        return nn.CrossEntropyLoss(weight=self.class_weight_t, ignore_index=255)

    def update_label(self, labels, prediction):
        # Keep only the hardest pixels: everything below the given loss
        # percentile is set to the ignore index.
        criterion = nn.CrossEntropyLoss(weight=self.class_weight,
                                        ignore_index=255, reduction='none')
        loss = criterion(prediction, labels)
        print('original loss: %f' % self.seg_loss(prediction, labels))
        loss_data = loss.data.cpu().numpy()
        mm = np.percentile(loss_data[:], self.only_hard_label)
        labels[loss < mm] = 255
        return labels

    def update_variance(self, labels, pred1, pred2):
        # Variance-rectified loss: down-weight pixels where the two classifier
        # heads disagree (large KL divergence), and penalize the disagreement itself.
        criterion = nn.CrossEntropyLoss(weight=self.class_weight,
                                        ignore_index=255, reduction='none')
        kl_distance = nn.KLDivLoss(reduction='none')
        loss = criterion(pred1, labels)
        variance = torch.sum(kl_distance(self.log_sm(pred1), self.sm(pred2)), dim=1)
        exp_variance = torch.exp(-variance)
        print(variance.shape)
        print('exp(-variance) mean: %.4f' % torch.mean(exp_variance[:]))
        print('exp(-variance) min: %.4f' % torch.min(exp_variance[:]))
        print('exp(-variance) max: %.4f' % torch.max(exp_variance[:]))
        loss = torch.mean(loss * exp_variance) + torch.mean(variance)
        return loss

    def update_variance_t(self, labels, pred1, pred2):
        # Target-domain version of update_variance, using the target class weights.
        criterion = nn.CrossEntropyLoss(weight=self.class_weight_t,
                                        ignore_index=255, reduction='none')
        kl_distance = nn.KLDivLoss(reduction='none')
        loss = criterion(pred1, labels)
        variance = torch.sum(kl_distance(self.log_sm(pred1), self.sm(pred2)), dim=1)
        exp_variance = torch.exp(-variance)
        print(variance.shape)
        print('exp(-variance) mean: %.4f' % torch.mean(exp_variance[:]))
        print('exp(-variance) min: %.4f' % torch.min(exp_variance[:]))
        print('exp(-variance) max: %.4f' % torch.max(exp_variance[:]))
        loss = torch.mean(loss * exp_variance) + torch.mean(variance)
        return loss

    def update_loss(self, loss):
        if self.fp16:
            with amp.scale_loss(loss, self.gen_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

    def gen_update(self, images, images_t, labels, labels_t, i_iter):
        self.gen_opt.zero_grad()

        # Source domain: supervised segmentation loss.
        pred1, pred2 = self.G(images)
        pred1 = self.interp(pred1)
        pred2 = self.interp(pred2)

        if self.class_balance:
            self.seg_loss = self.update_class_criterion(labels)

        # The variance-rectified loss can also be applied on the source domain:
        # loss_seg1 = self.update_variance(labels, pred1, pred2)
        # loss_seg2 = self.update_variance(labels, pred2, pred1)
        loss_seg1 = self.seg_loss(pred1, labels)
        loss_seg2 = self.seg_loss(pred2, labels)
        loss = loss_seg2 + self.lambda_seg * loss_seg1
        self.update_loss(loss)

        # Target domain: pseudo-label training with the variance-rectified loss.
        images_t = images_t.cuda()
        labels_t = labels_t.long().cuda()
        pred1_t, pred2_t = self.G(images_t)
        pred1_t = self.interp(pred1_t)
        pred2_t = self.interp(pred2_t)

        if self.class_balance:
            self.seg_loss_t = self.update_class_criterion_t(labels_t)

        loss_seg1_t = self.update_variance_t(labels_t, pred1_t, pred2_t)
        loss_seg2_t = self.update_variance_t(labels_t, pred2_t, pred1_t)
        loss = loss_seg2_t + self.lambda_seg * loss_seg1_t
        self.update_loss(loss)

        self.gen_opt.step()

        # This variant has no adversarial/entropy/KL terms; return zeros so the
        # caller sees a uniform interface.
        zero_loss = torch.zeros(1).cuda()
        return (loss_seg1, loss_seg2, loss_seg1_t, loss_seg2_t,
                zero_loss, zero_loss, zero_loss, zero_loss,
                pred1, pred2, None, None)

    def dis_update(self, pred1, pred2, pred_target1, pred_target2):
        self.dis1_opt.zero_grad()
        self.dis2_opt.zero_grad()

        # The discriminators only see detached predictions.
        pred1 = pred1.detach()
        pred2 = pred2.detach()
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()

        if self.multi_gpu:
            loss_D1, reg1 = self.D1.module.calc_dis_loss(
                self.D1,
                input_fake=F.softmax(pred_target1, dim=1),
                input_real=F.softmax(0.5 * pred1 + pred2, dim=1))
            loss_D2, reg2 = self.D2.module.calc_dis_loss(
                self.D2,
                input_fake=F.softmax(pred_target2, dim=1),
                input_real=F.softmax(0.5 * pred1 + pred2, dim=1))
        else:
            loss_D1, reg1 = self.D1.calc_dis_loss(
                self.D1,
                input_fake=F.softmax(pred_target1, dim=1),
                input_real=F.softmax(0.5 * pred1 + pred2, dim=1))
            loss_D2, reg2 = self.D2.calc_dis_loss(
                self.D2,
                input_fake=F.softmax(pred_target2, dim=1),
                input_real=F.softmax(0.5 * pred1 + pred2, dim=1))

        loss = loss_D1 + loss_D2
        if self.fp16:
            with amp.scale_loss(loss, [self.dis1_opt, self.dis2_opt]) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.dis1_opt.step()
        self.dis2_opt.step()
        return loss_D1, loss_D2
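
# A minimal, self-contained sketch of the rectification idea in update_variance
# above: the per-pixel KL divergence between the two classifier heads becomes a
# confidence weight exp(-KL) on the cross-entropy loss, plus a penalty on the
# disagreement itself. All shapes and tensors here are toy values for
# illustration, not the trainer's real inputs.
def _variance_rectification_demo():
    n, c, h, w = 2, 19, 8, 8
    pred1 = torch.randn(n, c, h, w)
    pred2 = torch.randn(n, c, h, w)
    labels = torch.randint(0, c, (n, h, w))

    ce = nn.CrossEntropyLoss(ignore_index=255, reduction='none')
    kl = nn.KLDivLoss(reduction='none')
    log_sm = nn.LogSoftmax(dim=1)
    sm = nn.Softmax(dim=1)

    loss_map = ce(pred1, labels)                               # (n, h, w)
    variance = torch.sum(kl(log_sm(pred1), sm(pred2)), dim=1)  # per-pixel KL
    exp_variance = torch.exp(-variance)                        # in (0, 1]
    # Confident pixels (heads agree) keep full CE weight; the +variance term
    # keeps the weights from collapsing to zero everywhere.
    rectified = torch.mean(loss_map * exp_variance) + torch.mean(variance)
    print('rectified loss: %.4f' % rectified.item())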
# ---------------------------------------------------------------------------
# A second AD_Trainer variant follows. The backbone/discriminator setup is the
# same as above, but gen_update replaces the variance-rectified pseudo-label
# loss with adversarial output alignment plus entropy (ME) and KL consistency
# regularizers on the target domain.
# ---------------------------------------------------------------------------
class AD_Trainer(nn.Module):
    def __init__(self, args):
        super(AD_Trainer, self).__init__()
        self.fp16 = args.fp16
        self.class_balance = args.class_balance
        self.often_balance = args.often_balance
        self.num_classes = args.num_classes
        self.class_weight = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        self.often_weight = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        self.multi_gpu = args.multi_gpu
        self.only_hard_label = args.only_hard_label

        if args.model == 'DeepLab':
            self.G = DeeplabMulti(num_classes=args.num_classes,
                                  use_se=args.use_se,
                                  train_bn=args.train_bn,
                                  norm_style=args.norm_style,
                                  droprate=args.droprate)
            if args.restore_from[:4] == 'http':
                saved_state_dict = model_zoo.load_url(args.restore_from)
            else:
                saved_state_dict = torch.load(args.restore_from)
            new_params = self.G.state_dict().copy()
            for i in saved_state_dict:
                # Example key: Scale.layer5.conv2d_list.3.weight
                i_parts = i.split('.')
                if args.restore_from[:4] == 'http':
                    if i_parts[1] != 'fc' and i_parts[1] != 'layer5':
                        new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' % i_parts[1:])
                else:
                    if i_parts[0] == 'module':
                        new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' % i_parts[1:])
                    else:
                        new_params['.'.join(i_parts)] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' % i_parts)
            self.G.load_state_dict(new_params)

        self.D1 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D2 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D1.apply(weights_init('gaussian'))
        self.D2.apply(weights_init('gaussian'))

        if self.multi_gpu and args.sync_bn:
            print("using apex synced BN")
            self.G = apex.parallel.convert_syncbn_model(self.G)

        self.gen_opt = optim.SGD(self.G.optim_parameters(args),
                                 lr=args.learning_rate,
                                 momentum=args.momentum,
                                 nesterov=True,
                                 weight_decay=args.weight_decay)
        self.dis1_opt = optim.Adam(self.D1.parameters(),
                                   lr=args.learning_rate_D, betas=(0.9, 0.99))
        self.dis2_opt = optim.Adam(self.D2.parameters(),
                                   lr=args.learning_rate_D, betas=(0.9, 0.99))
        self.seg_loss = nn.CrossEntropyLoss(ignore_index=255)
        # reduction='sum' replaces the deprecated size_average=False.
        self.kl_loss = nn.KLDivLoss(reduction='sum')
        self.sm = torch.nn.Softmax(dim=1)
        self.log_sm = torch.nn.LogSoftmax(dim=1)
        self.G = self.G.cuda()
        self.D1 = self.D1.cuda()
        self.D2 = self.D2.cuda()
        self.interp = nn.Upsample(size=args.crop_size, mode='bilinear', align_corners=True)
        self.interp_target = nn.Upsample(size=args.crop_size, mode='bilinear', align_corners=True)
        self.lambda_seg = args.lambda_seg
        self.max_value = args.max_value
        self.lambda_me_target = args.lambda_me_target
        self.lambda_kl_target = args.lambda_kl_target
        self.lambda_adv_target1 = args.lambda_adv_target1
        self.lambda_adv_target2 = args.lambda_adv_target2
        self.class_w = torch.FloatTensor(self.num_classes).zero_().cuda() + 1

        if args.fp16:
            # Wrap the models and optimizers with apex amp for mixed precision.
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.G, self.gen_opt = amp.initialize(self.G, self.gen_opt, opt_level="O1")
            self.D1, self.dis1_opt = amp.initialize(self.D1, self.dis1_opt, opt_level="O1")
            self.D2, self.dis2_opt = amp.initialize(self.D2, self.dis2_opt, opt_level="O1")

    def update_class_criterion(self, labels):
        # Re-weight rare (small-area) classes in the current batch.
        weight = torch.FloatTensor(self.num_classes).zero_().cuda()
        weight += 1
        count = torch.FloatTensor(self.num_classes).zero_().cuda()
        often = torch.FloatTensor(self.num_classes).zero_().cuda()
        often += 1
        print(labels.shape)
        n, h, w = labels.shape
        for i in range(self.num_classes):
            count[i] = torch.sum(labels == i)
            if count[i] < 64 * 64 * n:  # small object
                weight[i] = self.max_value
        if self.often_balance:
            often[count == 0] = self.max_value
        self.often_weight = 0.9 * self.often_weight + 0.1 * often
        self.class_weight = weight * self.often_weight
        print(self.class_weight)
        return nn.CrossEntropyLoss(weight=self.class_weight, ignore_index=255)

    def update_label(self, labels, prediction):
        # Keep only the hardest pixels: everything below the given loss
        # percentile is set to the ignore index.
        criterion = nn.CrossEntropyLoss(weight=self.class_weight,
                                        ignore_index=255, reduction='none')
        loss = criterion(prediction, labels)
        print('original loss: %f' % self.seg_loss(prediction, labels))
        loss_data = loss.data.cpu().numpy()
        mm = np.percentile(loss_data[:], self.only_hard_label)
        labels[loss < mm] = 255
        return labels

    def gen_update(self, images, images_t, labels, labels_t, i_iter):
        self.gen_opt.zero_grad()

        # Source domain: supervised segmentation loss.
        pred1, pred2 = self.G(images)
        pred1 = self.interp(pred1)
        pred2 = self.interp(pred2)

        if self.class_balance:
            self.seg_loss = self.update_class_criterion(labels)

        if self.only_hard_label > 0:
            labels1 = self.update_label(labels.clone(), pred1)
            labels2 = self.update_label(labels.clone(), pred2)
            loss_seg1 = self.seg_loss(pred1, labels1)
            loss_seg2 = self.seg_loss(pred2, labels2)
        else:
            loss_seg1 = self.seg_loss(pred1, labels)
            loss_seg2 = self.seg_loss(pred2, labels)
        loss = loss_seg2 + self.lambda_seg * loss_seg1

        # Target domain: adversarial alignment of the output distributions.
        pred_target1, pred_target2 = self.G(images_t)
        pred_target1 = self.interp_target(pred_target1)
        pred_target2 = self.interp_target(pred_target2)

        if self.multi_gpu:
            loss_adv_target1 = self.D1.module.calc_gen_loss(
                self.D1, input_fake=F.softmax(pred_target1, dim=1))
            loss_adv_target2 = self.D2.module.calc_gen_loss(
                self.D2, input_fake=F.softmax(pred_target2, dim=1))
        else:
            loss_adv_target1 = self.D1.calc_gen_loss(
                self.D1, input_fake=F.softmax(pred_target1, dim=1))
            loss_adv_target2 = self.D2.calc_gen_loss(
                self.D2, input_fake=F.softmax(pred_target2, dim=1))
        loss += self.lambda_adv_target1 * loss_adv_target1 \
            + self.lambda_adv_target2 * loss_adv_target2

        # Warm up without the entropy/KL regularizers for the first 15k iterations.
        if i_iter < 15000:
            self.lambda_kl_target_copy = 0
            self.lambda_me_target_copy = 0
        else:
            self.lambda_kl_target_copy = self.lambda_kl_target
            self.lambda_me_target_copy = self.lambda_me_target

        loss_me = 0.0
        if self.lambda_me_target_copy > 0:
            # Confidence-weighted entropy minimization on the fused prediction.
            confidence_map = torch.sum(
                self.sm(0.5 * pred_target1 + pred_target2) ** 2, 1).detach()
            loss_me = -torch.mean(confidence_map * torch.sum(
                self.sm(0.5 * pred_target1 + pred_target2)
                * self.log_sm(0.5 * pred_target1 + pred_target2), 1))
            loss += self.lambda_me_target * loss_me

        loss_kl = 0.0
        if self.lambda_kl_target_copy > 0:
            n, c, h, w = pred_target1.shape
            with torch.no_grad():
                # Detached target: the fused prediction of the two heads.
                mean_pred = self.sm(0.5 * pred_target1 + pred_target2)
            # KL consistency between each head and the fused prediction,
            # normalized per pixel.
            loss_kl = (self.kl_loss(self.log_sm(pred_target2), mean_pred)
                       + self.kl_loss(self.log_sm(pred_target1), mean_pred)) / (n * h * w)
            print(loss_kl)
            loss += self.lambda_kl_target * loss_kl

        if self.fp16:
            with amp.scale_loss(loss, self.gen_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.gen_opt.step()

        # Target labels enter only this monitoring term.
        val_loss = self.seg_loss(pred_target2, labels_t)

        return (loss_seg1, loss_seg2, loss_adv_target1, loss_adv_target2,
                loss_me, loss_kl, pred1, pred2, pred_target1, pred_target2,
                val_loss)

    def dis_update(self, pred1, pred2, pred_target1, pred_target2):
        self.dis1_opt.zero_grad()
        self.dis2_opt.zero_grad()

        # The discriminators only see detached predictions.
        pred1 = pred1.detach()
        pred2 = pred2.detach()
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()

        if self.multi_gpu:
            loss_D1, reg1 = self.D1.module.calc_dis_loss(
                self.D1,
                input_fake=F.softmax(pred_target1, dim=1),
                input_real=F.softmax(0.5 * pred1 + pred2, dim=1))
            loss_D2, reg2 = self.D2.module.calc_dis_loss(
                self.D2,
                input_fake=F.softmax(pred_target2, dim=1),
                input_real=F.softmax(0.5 * pred1 + pred2, dim=1))
        else:
            loss_D1, reg1 = self.D1.calc_dis_loss(
                self.D1,
                input_fake=F.softmax(pred_target1, dim=1),
                input_real=F.softmax(0.5 * pred1 + pred2, dim=1))
            loss_D2, reg2 = self.D2.calc_dis_loss(
                self.D2,
                input_fake=F.softmax(pred_target2, dim=1),
                input_real=F.softmax(0.5 * pred1 + pred2, dim=1))

        loss = loss_D1 + loss_D2
        if self.fp16:
            with amp.scale_loss(loss, [self.dis1_opt, self.dis2_opt]) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.dis1_opt.step()
        self.dis2_opt.step()
        return loss_D1, loss_D2
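
# A minimal sketch of how the adversarial trainer above might be driven. It
# assumes a CUDA device, an `args` namespace like the one the repository's
# train script builds, and two hypothetical data loaders (`source_loader`,
# `target_loader`) yielding (image, label) batches; none of those names come
# from this file.
def _train_sketch(args, trainer, source_loader, target_loader, num_steps):
    for i_iter, ((images, labels), (images_t, labels_t)) in enumerate(
            zip(source_loader, target_loader)):
        if i_iter >= num_steps:
            break
        images, labels = images.cuda(), labels.long().cuda()
        images_t, labels_t = images_t.cuda(), labels_t.long().cuda()

        # Generator step: source segmentation loss plus target regularizers.
        (loss_seg1, loss_seg2, loss_adv1, loss_adv2, loss_me, loss_kl,
         pred1, pred2, pred_t1, pred_t2, val_loss) = trainer.gen_update(
            images, images_t, labels, labels_t, i_iter)

        # Discriminator step on the (detached) source/target predictions.
        loss_D1, loss_D2 = trainer.dis_update(pred1, pred2, pred_t1, pred_t2)

        if i_iter % 100 == 0:
            print('iter %d: seg2 %.3f D1 %.3f D2 %.3f val %.3f'
                  % (i_iter, loss_seg2.item(), loss_D1.item(),
                     loss_D2.item(), val_loss.item()))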