class AD_Trainer(nn.Module):
    """Adversarial domain-adaptation trainer (variant with separate source/target class weights).

    Wraps a two-head DeepLab generator ``G`` plus two multi-scale image
    discriminators ``D1``/``D2`` (project class ``MsImageDis``) together with
    their optimizers.  ``gen_update`` performs one generator step on a source
    and a target batch; ``dis_update`` performs one discriminator step.
    Everything is moved to CUDA in ``__init__``, so a GPU is required.
    """

    def __init__(self, args):
        super(AD_Trainer, self).__init__()
        self.fp16 = args.fp16
        self.class_balance = args.class_balance
        self.often_balance = args.often_balance
        self.num_classes = args.num_classes
        # Running class-balance weights, kept separately for the source domain
        # (class_weight/often_weight) and the target domain (*_t variants).
        # All start at 1 (no re-weighting).
        self.class_weight = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        self.often_weight = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        self.class_weight_t = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        self.often_weight_t = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        self.multi_gpu = args.multi_gpu
        self.only_hard_label = args.only_hard_label
        if args.model == 'DeepLab':
            self.G = DeeplabMulti(num_classes=args.num_classes, use_se = args.use_se, train_bn = args.train_bn, norm_style = args.norm_style, droprate = args.droprate)
            if args.restore_from[:4] == 'http' :
                saved_state_dict = model_zoo.load_url(args.restore_from)
            else:
                saved_state_dict = torch.load(args.restore_from)
            new_params = self.G.state_dict().copy()
            for i in saved_state_dict:
                # Checkpoint key layout, e.g. Scale.layer5.conv2d_list.3.weight
                i_parts = i.split('.')
                if args.restore_from[:4] == 'http' :
                    # Downloaded backbone checkpoint: drop the leading scope name
                    # and skip the classifier head ('fc' / 'layer5') weights.
                    if i_parts[1] !='fc' and i_parts[1] !='layer5':
                        new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n'%i_parts[1:])
                else:
                    #new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                    # Local checkpoint: strip a leading 'module.' prefix (left by
                    # DataParallel) when present, otherwise load keys verbatim.
                    if i_parts[0] =='module':
                        new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n'%i_parts[1:])
                    else:
                        new_params['.'.join(i_parts[0:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n'%i_parts[0:])
            self.G.load_state_dict(new_params)
        self.D1 = MsImageDis(input_dim = args.num_classes).cuda()
        self.D2 = MsImageDis(input_dim = args.num_classes).cuda()
        self.D1.apply(weights_init('gaussian'))
        self.D2.apply(weights_init('gaussian'))
        if self.multi_gpu and args.sync_bn:
            print("using apex synced BN")
            self.G = apex.parallel.convert_syncbn_model(self.G)
        self.gen_opt = optim.SGD(self.G.optim_parameters(args), lr=args.learning_rate, momentum=args.momentum, nesterov=True, weight_decay=args.weight_decay)
        self.dis1_opt = optim.Adam(self.D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
        self.dis2_opt = optim.Adam(self.D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
        self.seg_loss = nn.CrossEntropyLoss(ignore_index=255)
        # NOTE(review): size_average=False is deprecated in modern PyTorch and is
        # equivalent to reduction='sum' -- worth migrating when touching this.
        self.kl_loss = nn.KLDivLoss(size_average=False)
        self.sm = torch.nn.Softmax(dim = 1)
        self.log_sm = torch.nn.LogSoftmax(dim = 1)
        self.G = self.G.cuda()
        self.D1 = self.D1.cuda()
        self.D2 = self.D2.cuda()
        # Both upsamplers target the same crop size, so they are interchangeable.
        self.interp = nn.Upsample(size= args.crop_size, mode='bilinear', align_corners=True)
        self.interp_target = nn.Upsample(size= args.crop_size, mode='bilinear', align_corners=True)
        self.lambda_seg = args.lambda_seg
        self.max_value = args.max_value
        self.lambda_me_target = args.lambda_me_target
        self.lambda_kl_target = args.lambda_kl_target
        self.lambda_adv_target1 = args.lambda_adv_target1
        self.lambda_adv_target2 = args.lambda_adv_target2
        self.class_w = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        if args.fp16:
            # Name the FP16_Optimizer instance to replace the existing optimizer
            # (apex amp must wrap model+optimizer AFTER both are constructed).
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.G, self.gen_opt = amp.initialize(self.G, self.gen_opt, opt_level="O1")
            self.D1, self.dis1_opt = amp.initialize(self.D1, self.dis1_opt, opt_level="O1")
            self.D2, self.dis2_opt = amp.initialize(self.D2, self.dis2_opt, opt_level="O1")

    def update_class_criterion(self, labels):
        """Recompute source-domain class weights from the current label batch.

        A class occupying fewer than 64*64*n pixels (small object; original
        train size is 512*256) gets weight ``max_value``.  With
        ``often_balance`` on, classes absent from the batch additionally boost
        a 0.9/0.1 EMA (``often_weight``).  Returns a CrossEntropyLoss weighted
        by the product of both terms; also updates ``self.class_weight``.
        """
        weight = torch.FloatTensor(self.num_classes).zero_().cuda()
        weight += 1
        count = torch.FloatTensor(self.num_classes).zero_().cuda()
        often = torch.FloatTensor(self.num_classes).zero_().cuda()
        often += 1
        n, h, w = labels.shape
        for i in range(self.num_classes):
            count[i] = torch.sum(labels==i)
            if count[i] < 64*64*n:  #small objective, original train size is 512*256
                weight[i] = self.max_value
        if self.often_balance:
            often[count == 0] = self.max_value
        self.often_weight = 0.9 * self.often_weight + 0.1 * often
        self.class_weight = weight * self.often_weight
        print(self.class_weight)
        return nn.CrossEntropyLoss(weight = self.class_weight, ignore_index=255)

    def update_class_criterion_t(self, labels):
        """Target-domain twin of :meth:`update_class_criterion`.

        Identical logic, but maintains ``class_weight_t``/``often_weight_t``
        so the target EMA does not mix with the source one.
        """
        weight = torch.FloatTensor(self.num_classes).zero_().cuda()
        weight += 1
        count = torch.FloatTensor(self.num_classes).zero_().cuda()
        often = torch.FloatTensor(self.num_classes).zero_().cuda()
        often += 1
        n, h, w = labels.shape
        for i in range(self.num_classes):
            count[i] = torch.sum(labels==i)
            if count[i] < 64*64*n:  #small objective, original train size is 512*256
                weight[i] = self.max_value
        if self.often_balance:
            often[count == 0] = self.max_value
        self.often_weight_t = 0.9 * self.often_weight_t + 0.1 * often
        self.class_weight_t = weight * self.often_weight_t
        print(self.class_weight_t)
        return nn.CrossEntropyLoss(weight = self.class_weight_t, ignore_index=255)

    def update_label(self, labels, prediction):
        """Hard-pixel mining: mark easy pixels as ignore.

        Pixels whose per-pixel CE loss falls below the ``only_hard_label``
        percentile are set to the ignore index 255, in place, so only the
        hardest pixels contribute to the subsequent loss.
        """
        criterion = nn.CrossEntropyLoss(weight = self.class_weight, ignore_index=255, reduction = 'none')
        #criterion = self.seg_loss
        loss = criterion(prediction, labels)
        print('original loss: %f'% self.seg_loss(prediction, labels) )
        #mm = torch.median(loss)
        loss_data = loss.data.cpu().numpy()
        mm = np.percentile(loss_data[:], self.only_hard_label)
        labels[loss < mm] = 255
        return labels

    def update_variance(self, labels, pred1, pred2):
        """Source-domain rectified loss for head ``pred1``.

        Per-pixel CE on ``pred1`` is down-weighted by exp(-variance), where
        the variance is the per-pixel KL divergence between the two heads'
        softmax outputs; the mean variance is added back as a regularizer.
        """
        criterion = nn.CrossEntropyLoss(weight = self.class_weight, ignore_index=255, reduction = 'none')
        kl_distance = nn.KLDivLoss( reduction = 'none')
        loss = criterion(pred1, labels)
        variance = torch.sum(kl_distance(self.log_sm(pred1),self.sm(pred2)), dim=1)
        exp_variance = torch.exp(-variance)
        print(variance.shape)
        print('variance mean: %.4f'%torch.mean(exp_variance[:]))
        print('variance min: %.4f'%torch.min(exp_variance[:]))
        print('variance max: %.4f'%torch.max(exp_variance[:]))
        loss = torch.mean(loss*exp_variance) + torch.mean(variance)
        return loss

    def update_variance_t(self, labels, pred1, pred2):
        """Target-domain twin of :meth:`update_variance` using ``class_weight_t``."""
        criterion = nn.CrossEntropyLoss(weight = self.class_weight_t, ignore_index=255, reduction = 'none')
        kl_distance = nn.KLDivLoss( reduction = 'none')
        loss = criterion(pred1, labels)
        variance = torch.sum(kl_distance(self.log_sm(pred1),self.sm(pred2)), dim=1)
        exp_variance = torch.exp(-variance)
        print(variance.shape)
        print('variance mean: %.4f'%torch.mean(exp_variance[:]))
        print('variance min: %.4f'%torch.min(exp_variance[:]))
        print('variance max: %.4f'%torch.max(exp_variance[:]))
        loss = torch.mean(loss*exp_variance) + torch.mean(variance)
        return loss

    def update_loss(self, loss):
        """Backward pass that routes through apex amp loss scaling in fp16 mode."""
        if self.fp16:
            with amp.scale_loss(loss, self.gen_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

    def gen_update(self, images, images_t, labels, labels_t, i_iter):
        """One generator step over a source batch and a target batch.

        Source: plain (optionally class-balanced) CE on both heads.
        Target: variance-rectified CE via :meth:`update_variance_t`.
        Gradients from both domains are accumulated via two backward passes
        before a single optimizer step.  The zero/None entries in the return
        tuple are placeholders (no adversarial/me/kl losses in this variant).
        """
        self.gen_opt.zero_grad()
        pred1, pred2 = self.G(images)
        pred1 = self.interp(pred1)
        pred2 = self.interp(pred2)
        if self.class_balance:
            self.seg_loss = self.update_class_criterion(labels)
        # calculate seg loss weighted by kldivloss
        # loss_seg1 = self.update_variance(labels, pred1, pred2)
        # loss_seg2 = self.update_variance(labels, pred2, pred1)
        loss_seg1 = self.seg_loss(pred1, labels)
        loss_seg2 = self.seg_loss(pred2, labels)
        loss = loss_seg2 + self.lambda_seg * loss_seg1
        self.update_loss(loss)
        images_t = images_t.cuda()
        labels_t = labels_t.long().cuda()
        pred1_t, pred2_t = self.G(images_t)
        # self.interp and self.interp_target have the same size, so using
        # self.interp here is equivalent.
        pred1_t = self.interp(pred1_t)
        pred2_t = self.interp(pred2_t)
        if self.class_balance:
            self.seg_loss_t = self.update_class_criterion_t(labels_t)
        # calculate seg loss weighted by kldivloss
        loss_seg1_t = self.update_variance_t(labels_t, pred1_t, pred2_t)
        loss_seg2_t = self.update_variance_t(labels_t, pred2_t, pred1_t)
        loss = loss_seg2_t + self.lambda_seg * loss_seg1_t
        self.update_loss(loss)
        self.gen_opt.step()
        zero_loss = torch.zeros(1).cuda()
        return loss_seg1, loss_seg2, loss_seg1_t, loss_seg2_t, zero_loss, zero_loss, zero_loss, zero_loss, pred1, pred2, None, None

    def dis_update(self, pred1, pred2, pred_target1, pred_target2):
        """One step for both discriminators.

        Real = softmax of the fused source prediction (0.5*pred1 + pred2);
        fake = softmax of the corresponding target prediction.  All inputs are
        detached so no gradient reaches the generator.  NOTE(review): the
        discriminator module is also passed as the first argument to
        calc_dis_loss -- presumably so the same code path works through a
        DataParallel wrapper; confirm against MsImageDis.
        """
        self.dis1_opt.zero_grad()
        self.dis2_opt.zero_grad()
        pred1 = pred1.detach()
        pred2 = pred2.detach()
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()
        if self.multi_gpu:
            loss_D1, reg1 = self.D1.module.calc_dis_loss( self.D1, input_fake = F.softmax(pred_target1, dim=1), input_real = F.softmax(0.5*pred1 + pred2, dim=1) )
            loss_D2, reg2 = self.D2.module.calc_dis_loss( self.D2, input_fake = F.softmax(pred_target2, dim=1), input_real = F.softmax(0.5*pred1 + pred2, dim=1) )
        else:
            loss_D1, reg1 = self.D1.calc_dis_loss( self.D1, input_fake = F.softmax(pred_target1, dim=1), input_real = F.softmax(0.5*pred1 + pred2, dim=1) )
            loss_D2, reg2 = self.D2.calc_dis_loss( self.D2, input_fake = F.softmax(pred_target2, dim=1), input_real = F.softmax(0.5*pred1 + pred2, dim=1) )
        loss = loss_D1 + loss_D2
        if self.fp16:
            with amp.scale_loss(loss, [self.dis1_opt, self.dis2_opt]) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.dis1_opt.step()
        self.dis2_opt.step()
        return loss_D1, loss_D2
def main(): """Create the model and start the training.""" w, h = map(int, args.input_size.split(',')) input_size = (w, h) w, h = map(int, args.input_size_target.split(',')) input_size_target = (w, h) cudnn.enabled = True gpu = args.gpu criterion = DiceBCELoss() # criterion = nn.CrossEntropyLoss(ignore_index=253) # Create network if args.model == 'DeepLab': model = DeeplabMulti(num_classes=args.num_classes) if args.restore_from is None: pass elif args.restore_from[:4] == 'http': saved_state_dict = model_zoo.load_url(args.restore_from) elif args.restore_from is not None: saved_state_dict = torch.load(args.restore_from) if args.restore_from is not None: new_params = model.state_dict().copy() for i in saved_state_dict: # Scale.layer5.conv2d_list.3.weight i_parts = i.split('.') # print i_parts if not args.num_classes == 19 or not i_parts[1] == 'layer5': new_params['.'.join(i_parts[1:])] = saved_state_dict[i] # print i_parts model.load_state_dict(new_params) if not args.no_logging: if not os.path.isdir(args.log_dir): os.mkdir(args.log_dir) log_dir = os.path.join(args.log_dir, args.exp_dir) if not os.path.isdir(log_dir): os.mkdir(log_dir) if args.exp_name == "": exp_name = datetime.datetime.now().strftime("%H%M%S-%Y%m%d") else: exp_name = args.exp_name log_dir = os.path.join(log_dir, exp_name) writer = SummaryWriter(log_dir) model.train() model.cuda(args.gpu) cudnn.benchmark = True # init D model_D1 = FCDiscriminator(num_classes=args.num_classes) model_D2 = FCDiscriminator(num_classes=args.num_classes) model_D1.train() model_D1.cuda(args.gpu) model_D2.train() model_D2.cuda(args.gpu) if not os.path.exists(args.snapshot_dir): os.makedirs(args.snapshot_dir) trainloader = data.DataLoader(SyntheticSmokeTrain( args={}, dataset_limit=args.num_steps * args.iter_size * args.batch_size, image_shape=input_size, dataset_mean=IMG_MEAN), batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True) trainloader_iter = enumerate(trainloader) print("Length of 
train dataloader: ", len(trainloader)) targetloader = data.DataLoader(SimpleSmokeVal(args={}, image_size=input_size_target, dataset_mean=IMG_MEAN), batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True) targetloader_iter = enumerate(targetloader) print("Length of train dataloader: ", len(targetloader)) # implement model.optim_parameters(args) to handle different models' lr setting optimizer = optim.SGD(model.optim_parameters(args), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) optimizer.zero_grad() optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) optimizer_D1.zero_grad() optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) optimizer_D2.zero_grad() if args.gan == 'Vanilla': bce_loss = torch.nn.BCEWithLogitsLoss() # bce_loss_all = torch.nn.BCEWithLogitsLoss(reduction='none') elif args.gan == 'LS': bce_loss = torch.nn.MSELoss() # bce_loss_all = torch.nn.MSELoss(reduction='none') interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True) interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear', align_corners=True) # interp_domain = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear', align_corners=True) # labels for adversarial training source_label = 0 target_label = 1 for i_iter in range(args.num_steps): loss_seg_value1 = 0 loss_adv_target_value1 = 0 loss_D_value1 = 0 loss_seg_value2 = 0 loss_adv_target_value2 = 0 loss_D_value2 = 0 optimizer.zero_grad() adjust_learning_rate(optimizer, i_iter) optimizer_D1.zero_grad() optimizer_D2.zero_grad() adjust_learning_rate_D(optimizer_D1, i_iter) adjust_learning_rate_D(optimizer_D2, i_iter) for sub_i in range(args.iter_size): # train G # don't accumulate grads in D for param in model_D1.parameters(): param.requires_grad = False for param in model_D2.parameters(): 
param.requires_grad = False for param in interp_domain.parameters(): param.requires_grad = False # train with source # try: _, batch = next(trainloader_iter) #.next() # except StopIteration: # trainloader = data.DataLoader( # SyntheticSmokeTrain(args={}, dataset_limit=args.num_steps * args.iter_size * args.batch_size, # image_shape=input_size, dataset_mean=IMG_MEAN), # batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True) # trainloader_iter = iter(trainloader) # _, batch = next(trainloader_iter) images, labels, _, _ = batch images = Variable(images).cuda(args.gpu) # print("Shape of labels", labels.shape) # print("Are labels all zero? ") # for i in range(labels.shape[0]): # print("{}: All zero? {}".format(i, torch.all(labels[i]==0))) # print("{}: All 255? {}".format(i, torch.all(labels[i]==255))) # print("{}: Mean = {}".format(i, torch.mean(labels[i]))) pred1, pred2 = model(images) # print("Pred1 and Pred2 original size: {}, {}".format(pred1.shape, pred2.shape)) pred1 = interp(pred1) pred2 = interp(pred2) # print("Pred1 and Pred2 upsampled size: {}, {}".format(pred1.shape, pred2.shape)) # for pred, name in zip([pred1, pred2], ['pred1', 'pred2']): # print(name) # for i in range(pred.shape[0]): # print("{}: All zero? {}".format(i, torch.all(pred[i]==0))) # print("{}: All 255? 
{}".format(i, torch.all(pred[i]==255))) # print("{}: Mean = {}".format(i, torch.mean(pred[i]))) loss_seg1 = loss_calc(pred1, labels, args.gpu, criterion) loss_seg2 = loss_calc(pred2, labels, args.gpu, criterion) loss = loss_seg2 + args.lambda_seg * loss_seg1 # proper normalization loss = loss / args.iter_size loss.backward() # print("Seg1 loss: ",loss_seg1, args.iter_size) # print("Seg2 loss: ",loss_seg2, args.iter_size) loss_seg_value1 += loss_seg1.detach().data.cpu().item( ) / args.iter_size loss_seg_value2 += loss_seg2.detach().data.cpu().item( ) / args.iter_size # train with target # try: _, batch = next(targetloader_iter) #.next() # except StopIteration: # targetloader = data.DataLoader( # SimpleSmokeVal(args = {}, image_size=input_size_target, dataset_mean=IMG_MEAN), # batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, # pin_memory=True) # targetloader_iter = iter(targetloader) # _, batch = next(targetloader_iter) images, _, _ = batch images = Variable(images).cuda(args.gpu) pred_target1, pred_target2 = model(images) pred_target1 = interp_target(pred_target1) pred_target2 = interp_target(pred_target2) D_out1 = model_D1(F.softmax(pred_target1, dim=1)) D_out2 = model_D2(F.softmax(pred_target2, dim=1)) # w1 = torch.argmax(pred_target1.detach(), dim=1) # w2 = torch.argmax(pred_target2.detach(), dim=1) min_class1 = sorted([(k, v) for k, v in Counter(w1.ravel()).items()], key=lambda x: x[1])[0][0] min_class2 = sorted([(k, v) for k, v in Counter(w2.ravel()).items()], key=lambda x: x[1])[0][0] # m1 = torch.where(w1==min_class1) # m1c = torch.where(w1!=min_class1) # w1[m1] = 11 # w1[m1c] = 1 # m2 = torch.where(w2==min_class2) # m2c = torch.where(w2!=min_class2) # w2[m2] = 11 # w2[m2c] = 1 # D_out1 = interp_domain(D_out1) # D_out2 = interp_domain(D_out2) loss_adv_target1 = bce_loss( D_out1, Variable( torch.FloatTensor( D_out1.data.size()).fill_(source_label)).cuda( args.gpu)) loss_adv_target2 = bce_loss( D_out2, Variable( torch.FloatTensor( 
D_out2.data.size()).fill_(source_label)).cuda( args.gpu)) loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2 loss = loss / args.iter_size loss.backward() loss_adv_target_value1 += loss_adv_target1.detach().data.cpu( ).item() / args.iter_size loss_adv_target_value2 += loss_adv_target2.detach().data.cpu( ).item() / args.iter_size # train D # bring back requires_grad for param in model_D1.parameters(): param.requires_grad = True for param in model_D2.parameters(): param.requires_grad = True # train with source pred1 = pred1.detach() pred2 = pred2.detach() D_out1 = model_D1(F.softmax(pred1, dim=1)) D_out2 = model_D2(F.softmax(pred2, dim=1)) loss_D1 = bce_loss( D_out1, Variable( torch.FloatTensor( D_out1.data.size()).fill_(source_label)).cuda( args.gpu)) loss_D2 = bce_loss( D_out2, Variable( torch.FloatTensor( D_out2.data.size()).fill_(source_label)).cuda( args.gpu)) loss_D1 = loss_D1 / args.iter_size / 2 loss_D2 = loss_D2 / args.iter_size / 2 loss_D1.backward() loss_D2.backward() loss_D_value1 += loss_D1.detach().data.cpu().item() loss_D_value2 += loss_D2.detach().data.cpu().item() # train with target pred_target1 = pred_target1.detach() pred_target2 = pred_target2.detach() D_out1 = model_D1(F.softmax(pred_target1, dim=1)) D_out2 = model_D2(F.softmax(pred_target2, dim=1)) loss_D1 = bce_loss( D_out1, Variable( torch.FloatTensor( D_out1.data.size()).fill_(target_label)).cuda( args.gpu)) loss_D2 = bce_loss( D_out2, Variable( torch.FloatTensor( D_out2.data.size()).fill_(target_label)).cuda( args.gpu)) loss_D1 = loss_D1 / args.iter_size / 2 loss_D2 = loss_D2 / args.iter_size / 2 loss_D1.backward() loss_D2.backward() loss_D_value1 += loss_D1.detach().data.cpu().item() loss_D_value2 += loss_D2.detach().data.cpu().item() optimizer.step() optimizer_D1.step() optimizer_D2.step() print('exp = {}'.format(args.snapshot_dir)) print( 'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} 
loss_D1 = {6:.3f} loss_D2 = {7:.3f}' .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2, loss_adv_target_value1, loss_adv_target_value2, loss_D_value1, loss_D_value2)) writer.add_scalar(f'loss/train/segmentation/1', loss_seg_value1, i_iter) writer.add_scalar(f'loss/train/segmentation/2', loss_seg_value2, i_iter) writer.add_scalar(f'loss/train/adversarial/1', loss_adv_target_value1, i_iter) writer.add_scalar(f'loss/train/adversarial/2', loss_adv_target_value2, i_iter) writer.add_scalar(f'loss/train/domain/1', loss_D_value1, i_iter) writer.add_scalar(f'loss/train/domain/2', loss_D_value2, i_iter) if i_iter >= args.num_steps_stop - 1: print('save model ...') torch.save( model.state_dict(), osp.join(args.snapshot_dir, 'lmda_adv_0.1_' + str(args.num_steps_stop) + '.pth')) torch.save( model_D1.state_dict(), osp.join( args.snapshot_dir, 'lmda_adv_0.1_' + str(args.num_steps_stop) + '_D1.pth')) torch.save( model_D2.state_dict(), osp.join( args.snapshot_dir, 'lmda_adv_0.1_' + str(args.num_steps_stop) + '_D2.pth')) break if i_iter % args.save_pred_every == 0 and i_iter != 0: print('taking snapshot ...') torch.save( model.state_dict(), osp.join(args.snapshot_dir, 'lmda_adv_0.1_' + str(i_iter) + '.pth')) torch.save( model_D1.state_dict(), osp.join(args.snapshot_dir, 'lmda_adv_0.1_' + str(i_iter) + '_D1.pth')) torch.save( model_D2.state_dict(), osp.join(args.snapshot_dir, 'lmda_adv_0.1_' + str(i_iter) + '_D2.pth')) writer.flush()
class AD_Trainer(nn.Module):
    """Adversarial domain-adaptation trainer (variant with adversarial + me/kl target losses).

    Wraps a two-head DeepLab generator ``G`` and two multi-scale image
    discriminators ``D1``/``D2`` (project class ``MsImageDis``) together with
    their optimizers.  ``gen_update`` combines source segmentation losses with
    adversarial, max-entropy and KL-consistency losses on the target batch;
    ``dis_update`` performs one discriminator step.  CUDA-only.
    """

    def __init__(self, args):
        super(AD_Trainer, self).__init__()
        self.fp16 = args.fp16
        self.class_balance = args.class_balance
        self.often_balance = args.often_balance
        self.num_classes = args.num_classes
        # Running class-balance weights; both start at 1 (no re-weighting).
        self.class_weight = torch.FloatTensor(
            self.num_classes).zero_().cuda() + 1
        self.often_weight = torch.FloatTensor(
            self.num_classes).zero_().cuda() + 1
        self.multi_gpu = args.multi_gpu
        self.only_hard_label = args.only_hard_label
        if args.model == 'DeepLab':
            self.G = DeeplabMulti(num_classes=args.num_classes,
                                  use_se=args.use_se,
                                  train_bn=args.train_bn,
                                  norm_style=args.norm_style,
                                  droprate=args.droprate)
            if args.restore_from[:4] == 'http':
                saved_state_dict = model_zoo.load_url(args.restore_from)
            else:
                saved_state_dict = torch.load(args.restore_from)
            new_params = self.G.state_dict().copy()
            for i in saved_state_dict:
                # Checkpoint key layout, e.g. Scale.layer5.conv2d_list.3.weight
                i_parts = i.split('.')
                if args.restore_from[:4] == 'http':
                    # Downloaded backbone checkpoint: drop the leading scope
                    # name and skip the classifier ('fc'/'layer5') weights.
                    if i_parts[1] != 'fc' and i_parts[1] != 'layer5':
                        new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' % i_parts[1:])
                else:
                    # Local checkpoint: strip a leading 'module.' prefix (left
                    # by DataParallel) when present, otherwise load verbatim.
                    if i_parts[0] == 'module':
                        new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' % i_parts[1:])
                    else:
                        new_params['.'.join(i_parts[0:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' % i_parts[0:])
            self.G.load_state_dict(new_params)
        self.D1 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D2 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D1.apply(weights_init('gaussian'))
        self.D2.apply(weights_init('gaussian'))
        if self.multi_gpu and args.sync_bn:
            print("using apex synced BN")
            self.G = apex.parallel.convert_syncbn_model(self.G)
        self.gen_opt = optim.SGD(self.G.optim_parameters(args),
                                 lr=args.learning_rate,
                                 momentum=args.momentum,
                                 nesterov=True,
                                 weight_decay=args.weight_decay)
        self.dis1_opt = optim.Adam(self.D1.parameters(),
                                   lr=args.learning_rate_D, betas=(0.9, 0.99))
        self.dis2_opt = optim.Adam(self.D2.parameters(),
                                   lr=args.learning_rate_D, betas=(0.9, 0.99))
        self.seg_loss = nn.CrossEntropyLoss(ignore_index=255)
        # NOTE(review): size_average=False is deprecated in modern PyTorch and
        # is equivalent to reduction='sum' -- worth migrating when touching.
        self.kl_loss = nn.KLDivLoss(size_average=False)
        self.sm = torch.nn.Softmax(dim=1)
        self.log_sm = torch.nn.LogSoftmax(dim=1)
        self.G = self.G.cuda()
        self.D1 = self.D1.cuda()
        self.D2 = self.D2.cuda()
        # Both upsamplers target the same crop size, so they are equivalent.
        self.interp = nn.Upsample(size=args.crop_size, mode='bilinear',
                                  align_corners=True)
        self.interp_target = nn.Upsample(size=args.crop_size, mode='bilinear',
                                         align_corners=True)
        self.lambda_seg = args.lambda_seg
        self.max_value = args.max_value
        self.lambda_me_target = args.lambda_me_target
        self.lambda_kl_target = args.lambda_kl_target
        self.lambda_adv_target1 = args.lambda_adv_target1
        self.lambda_adv_target2 = args.lambda_adv_target2
        self.class_w = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        if args.fp16:
            # Name the FP16_Optimizer instance to replace the existing optimizer
            # (apex amp must wrap model+optimizer AFTER both are constructed).
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.G, self.gen_opt = amp.initialize(self.G, self.gen_opt, opt_level="O1")
            self.D1, self.dis1_opt = amp.initialize(self.D1, self.dis1_opt, opt_level="O1")
            self.D2, self.dis2_opt = amp.initialize(self.D2, self.dis2_opt, opt_level="O1")

    def update_class_criterion(self, labels):
        """Recompute class weights from the current label batch.

        A class occupying fewer than 64*64*n pixels (small object) gets weight
        ``max_value``; with ``often_balance`` on, classes absent from the batch
        boost a 0.9/0.1 EMA (``often_weight``).  Returns a CrossEntropyLoss
        weighted by the product of both terms; updates ``self.class_weight``.
        """
        weight = torch.FloatTensor(self.num_classes).zero_().cuda()
        weight += 1
        count = torch.FloatTensor(self.num_classes).zero_().cuda()
        often = torch.FloatTensor(self.num_classes).zero_().cuda()
        often += 1
        print(labels.shape)
        n, h, w = labels.shape
        for i in range(self.num_classes):
            count[i] = torch.sum(labels == i)
            if count[i] < 64 * 64 * n:  #small objective
                weight[i] = self.max_value
        if self.often_balance:
            often[count == 0] = self.max_value
        self.often_weight = 0.9 * self.often_weight + 0.1 * often
        self.class_weight = weight * self.often_weight
        print(self.class_weight)
        return nn.CrossEntropyLoss(weight=self.class_weight, ignore_index=255)

    def update_label(self, labels, prediction):
        """Hard-pixel mining: mark easy pixels as ignore.

        Pixels whose per-pixel CE loss falls below the ``only_hard_label``
        percentile are set to the ignore index 255, in place.
        """
        criterion = nn.CrossEntropyLoss(weight=self.class_weight,
                                        ignore_index=255,
                                        reduction='none')
        #criterion = self.seg_loss
        loss = criterion(prediction, labels)
        print('original loss: %f' % self.seg_loss(prediction, labels))
        #mm = torch.median(loss)
        loss_data = loss.data.cpu().numpy()
        mm = np.percentile(loss_data[:], self.only_hard_label)
        labels[loss < mm] = 255
        return labels

    def gen_update(self, images, images_t, labels, labels_t, i_iter):
        """One generator step over a source batch and a target batch.

        Source: (optionally class-balanced / hard-label-mined) CE on both
        heads.  Target: adversarial losses from D1/D2 plus -- only after
        iteration 15000 -- an optional confidence-weighted entropy term
        (``loss_me``) and a KL-consistency term between the two heads
        (``loss_kl``).  Returns the individual losses, source and target
        predictions, and a CE loss of pred_target2 against ``labels_t``
        (reported as a validation-style metric).
        """
        self.gen_opt.zero_grad()
        pred1, pred2 = self.G(images)
        pred1 = self.interp(pred1)
        pred2 = self.interp(pred2)
        if self.class_balance:
            self.seg_loss = self.update_class_criterion(labels)
        if self.only_hard_label > 0:
            # Mine hard pixels separately per head (labels are cloned so the
            # in-place 255 masking does not leak between the two heads).
            labels1 = self.update_label(labels.clone(), pred1)
            labels2 = self.update_label(labels.clone(), pred2)
            loss_seg1 = self.seg_loss(pred1, labels1)
            loss_seg2 = self.seg_loss(pred2, labels2)
        else:
            loss_seg1 = self.seg_loss(pred1, labels)
            loss_seg2 = self.seg_loss(pred2, labels)
        loss = loss_seg2 + self.lambda_seg * loss_seg1

        # target
        pred_target1, pred_target2 = self.G(images_t)
        pred_target1 = self.interp_target(pred_target1)
        pred_target2 = self.interp_target(pred_target2)

        # NOTE(review): the discriminator module is also passed as the first
        # argument to calc_gen_loss -- presumably so the same code path works
        # through a DataParallel wrapper; confirm against MsImageDis.
        if self.multi_gpu:
            loss_adv_target1 = self.D1.module.calc_gen_loss(
                self.D1, input_fake=F.softmax(pred_target1, dim=1))
            loss_adv_target2 = self.D2.module.calc_gen_loss(
                self.D2, input_fake=F.softmax(pred_target2, dim=1))
        else:
            loss_adv_target1 = self.D1.calc_gen_loss(self.D1,
                                                     input_fake=F.softmax(
                                                         pred_target1, dim=1))
            loss_adv_target2 = self.D2.calc_gen_loss(self.D2,
                                                     input_fake=F.softmax(
                                                         pred_target2, dim=1))
        loss += self.lambda_adv_target1 * loss_adv_target1 + self.lambda_adv_target2 * loss_adv_target2

        # Warm-up: entropy and KL terms are disabled for the first 15000 iters.
        if i_iter < 15000:
            self.lambda_kl_target_copy = 0
            self.lambda_me_target_copy = 0
        else:
            self.lambda_kl_target_copy = self.lambda_kl_target
            self.lambda_me_target_copy = self.lambda_me_target

        loss_me = 0.0
        if self.lambda_me_target_copy > 0:
            # Negative entropy of the fused prediction, weighted per pixel by
            # a (detached) confidence map = sum of squared class probabilities.
            confidence_map = torch.sum(
                self.sm(0.5 * pred_target1 + pred_target2)**2, 1).detach()
            loss_me = -torch.mean(confidence_map * torch.sum(
                self.sm(0.5 * pred_target1 + pred_target2) *
                self.log_sm(0.5 * pred_target1 + pred_target2), 1))
            loss += self.lambda_me_target * loss_me

        loss_kl = 0.0
        if self.lambda_kl_target_copy > 0:
            n, c, h, w = pred_target1.shape
            # Pull both heads toward the (gradient-free) fused prediction;
            # kl_loss sums, so normalize by the number of pixels.
            with torch.no_grad():
                mean_pred = self.sm(0.5 * pred_target1 + pred_target2)
            loss_kl = (self.kl_loss(self.log_sm(pred_target2), mean_pred) +
                       self.kl_loss(self.log_sm(pred_target1), mean_pred)) / (
                           n * h * w)
            print(loss_kl)
            loss += self.lambda_kl_target * loss_kl

        if self.fp16:
            with amp.scale_loss(loss, self.gen_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.gen_opt.step()

        val_loss = self.seg_loss(pred_target2, labels_t)
        return loss_seg1, loss_seg2, loss_adv_target1, loss_adv_target2, loss_me, loss_kl, pred1, pred2, pred_target1, pred_target2, val_loss

    def dis_update(self, pred1, pred2, pred_target1, pred_target2):
        """One step for both discriminators.

        Real = softmax of the fused source prediction (0.5*pred1 + pred2);
        fake = softmax of the corresponding target prediction.  All inputs are
        detached so no gradient reaches the generator.
        """
        self.dis1_opt.zero_grad()
        self.dis2_opt.zero_grad()
        pred1 = pred1.detach()
        pred2 = pred2.detach()
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()
        if self.multi_gpu:
            loss_D1, reg1 = self.D1.module.calc_dis_loss(
                self.D1,
                input_fake=F.softmax(pred_target1, dim=1),
                input_real=F.softmax(0.5 * pred1 + pred2, dim=1))
            loss_D2, reg2 = self.D2.module.calc_dis_loss(
                self.D2,
                input_fake=F.softmax(pred_target2, dim=1),
                input_real=F.softmax(0.5 * pred1 + pred2, dim=1))
        else:
            loss_D1, reg1 = self.D1.calc_dis_loss(
                self.D1,
                input_fake=F.softmax(pred_target1, dim=1),
                input_real=F.softmax(0.5 * pred1 + pred2, dim=1))
            loss_D2, reg2 = self.D2.calc_dis_loss(
                self.D2,
                input_fake=F.softmax(pred_target2, dim=1),
                input_real=F.softmax(0.5 * pred1 + pred2, dim=1))
        loss = loss_D1 + loss_D2
        if self.fp16:
            with amp.scale_loss(loss, [self.dis1_opt, self.dis2_opt]) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.dis1_opt.step()
        self.dis2_opt.step()
        return loss_D1, loss_D2
def main():
    """Create the model and start the (optionally distributed) training.

    Builds a segmentation network and two FCDiscriminators, restores
    pretrained weights, sets up the GTA5->BDD source/target loaders and the
    three optimizers, then runs the adversarial adaptation loop.  All
    configuration comes from the module-level `args` (re-read here via
    `get_arguments()`).
    """
    global args
    args = get_arguments()
    if args.dist:
        init_dist(args.launcher, backend=args.backend)
    world_size = 1
    rank = 0
    if args.dist:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    device = torch.device("cuda" if not args.cpu else "cpu")

    # Crop sizes arrive as "W,H" strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network and restore pretrained weights.
    if args.model == 'Deeplab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            # FIX: torch.load() takes no `strict` keyword (that parameter
            # belongs to load_state_dict) -- the original call raised a
            # TypeError before any weight was read.
            saved_state_dict = torch.load(args.restore_from)
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Keys look like 'Scale.layer5.conv2d_list.3.weight': drop the
            # leading scope; skip the classifier head when num_classes == 19.
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        model.load_state_dict(saved_state_dict, strict=False)
    elif args.model == 'DeeplabVGGBN':
        deeplab_vggbn.BatchNorm = SyncBatchNorm2d
        model = deeplab_vggbn.DeeplabVGGBN(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        model.load_state_dict(saved_state_dict, strict=False)
        del saved_state_dict

    model.train()
    model.to(device)
    if args.dist:
        broadcast_params(model)
    if rank == 0:
        print(model)

    cudnn.benchmark = True

    # init D -- one discriminator per prediction head.
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    model_D1.train()
    model_D1.to(device)
    if args.dist:
        broadcast_params(model_D1)
    if args.restore_D is not None:
        D_dict = torch.load(args.restore_D)
        model_D1.load_state_dict(D_dict, strict=False)
        del D_dict

    model_D2.train()
    model_D2.to(device)
    if args.dist:
        broadcast_params(model_D2)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_data = GTA5BDDDataSet(args.data_dir, args.data_list,
                                max_iters=args.num_steps * args.iter_size * args.batch_size,
                                crop_size=input_size,
                                scale=args.random_scale,
                                mirror=args.random_mirror,
                                mean=IMG_MEAN)
    train_sampler = None
    if args.dist:
        train_sampler = DistributedSampler(train_data)
    trainloader = data.DataLoader(train_data,
                                  batch_size=args.batch_size,
                                  shuffle=False if train_sampler else True,
                                  num_workers=args.num_workers,
                                  pin_memory=False,
                                  sampler=train_sampler)
    trainloader_iter = enumerate(cycle(trainloader))

    target_data = BDDDataSet(args.data_dir_target, args.data_list_target,
                             max_iters=args.num_steps * args.iter_size * args.batch_size,
                             crop_size=input_size_target,
                             scale=False,
                             mirror=args.random_mirror,
                             mean=IMG_MEAN,
                             set=args.set)
    target_sampler = None
    if args.dist:
        target_sampler = DistributedSampler(target_data)
    targetloader = data.DataLoader(target_data,
                                   batch_size=args.batch_size,
                                   shuffle=False if target_sampler else True,
                                   num_workers=args.num_workers,
                                   pin_memory=False,
                                   sampler=target_sampler)
    targetloader_iter = enumerate(cycle(targetloader))

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D1.zero_grad()
    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    #interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board (rank 0 only)
    if args.tensorboard and rank == 0:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        writer = SummaryWriter(args.log_dir)

    torch.cuda.empty_cache()
    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G: don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False
            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            _, batch = trainloader_iter.__next__()
            images, labels, size, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)
            # Upsample predictions to this batch's label resolution;
            # assumes size is (H, W) per batch -- TODO confirm against loader.
            interp = nn.Upsample(size=(size[1], size[0]), mode='bilinear', align_corners=True)

            pred1 = model(images)
            pred1 = interp(pred1)
            loss_seg1 = seg_loss(pred1, labels)
            loss = loss_seg1
            # proper normalization across iter_size and ranks
            loss = loss / args.iter_size / world_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size

            # train with target
            _, batch = targetloader_iter.__next__()
            images, _, _ = batch
            images = images.to(device)

            pred_target1 = model(images)
            pred_target1 = interp_target(pred_target1)
            # Fool D1: target predictions should be classified as source.
            # (dim=1 made explicit; implicit-dim softmax is deprecated and
            # resolves to dim=1 for 4-D input anyway.)
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            loss_adv_target1 = bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(source_label).to(device))
            loss = args.lambda_adv_target1 * loss_adv_target1
            loss = loss / args.iter_size / world_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size

            # train D: bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred1 = pred1.detach()
            D_out1 = model_D1(F.softmax(pred1, dim=1))
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(source_label).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2 / world_size
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

            # train with target
            pred_target1 = pred_target1.detach()
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(target_label).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2 / world_size
            loss_D1.backward()
            if args.dist:
                average_gradients(model)
                average_gradients(model_D1)
                average_gradients(model_D2)
            loss_D_value1 += loss_D1.item()

        optimizer.step()
        optimizer_D1.step()

        if rank == 0:
            if args.tensorboard:
                scalar_info = {
                    'loss_seg1': loss_seg_value1,
                    'loss_seg2': loss_seg_value2,
                    'loss_adv_target1': loss_adv_target_value1,
                    'loss_adv_target2': loss_adv_target_value2,
                    'loss_D1': loss_D_value1 * world_size,
                    'loss_D2': loss_D_value2 * world_size,
                }
                if i_iter % 10 == 0:
                    for key, val in scalar_info.items():
                        writer.add_scalar(key, val, i_iter)

            print('exp = {}'.format(args.snapshot_dir))
            print(
                'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
                .format(i_iter, args.num_steps, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1, loss_D_value2))

            # NOTE(review): this break only executes on rank 0; in
            # distributed mode the other ranks keep looping -- confirm
            # whether that is intended.
            if i_iter >= args.num_steps_stop - 1:
                print('save model ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth'))
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
                break

            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
                print(args.snapshot_dir)

    if args.tensorboard and rank == 0:
        writer.close()
def main():
    """Create the model and start the semi-supervised adversarial training.

    Restores a DeeplabMulti backbone (caffe-style partial copy), optionally
    samples a labeled subset of the training data (`args.partial_data`), and
    alternates generator/discriminator updates with an optional
    semi-supervised phase on the remaining unlabeled images.  Configuration
    comes from the module-level `args`.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    gpu = args.gpu

    # create network
    model = DeeplabMulti(num_classes=args.num_classes)
    #model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from, map_location='cuda:0')

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)
    #summary(model,(3,7,7))
    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)
    #summary(model_D, (21,321,321))
    #quit()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = cityscapesDataSet(max_iters=args.num_steps * args.iter_size * args.batch_size,
                                      scale=args.random_scale)
    train_dataset_size = len(train_dataset)
    train_gt_dataset = cityscapesDataSet(max_iters=args.num_steps * args.iter_size * args.batch_size,
                                         scale=args.random_scale)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size, shuffle=True,
                                      num_workers=args.num_workers, pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size, shuffle=True,
                                         num_workers=5, pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # FIX: pickle streams are binary; opening in text mode fails on
            # Python 3 (UnicodeDecodeError) -- open with 'rb'.
            train_ids = pickle.load(open(args.partial_id, 'rb'))
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)
            pickle.dump(train_ids, open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb'))

        train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.batch_size, shuffle=False,
                                      num_workers=args.num_workers, pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             sampler=train_remain_sampler,
                                             batch_size=args.batch_size, shuffle=False,
                                             num_workers=args.num_workers, pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         sampler=train_gt_sampler,
                                         batch_size=args.batch_size, shuffle=False,
                                         num_workers=args.num_workers, pin_memory=True)

    # NOTE(review): the "remain" iterator is seeded from `trainloader`, while
    # the refill in the except-branch below uses `trainloader_remain`, which
    # only exists when args.partial_data is set -- confirm intended.
    trainloader_remain_iter = enumerate(trainloader)
    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting
    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss / bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G: don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            if (args.lambda_semi > 0 or args.lambda_semi_adv > 0) and i_iter >= args.semi_start_adv:
                try:
                    _, batch = trainloader_remain_iter.__next__()
                except:
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = trainloader_remain_iter.__next__()

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                try:
                    pred = interp(model(images))
                except RuntimeError as exception:
                    if "out of memory" in str(exception):
                        print("WARNING: out of memory")
                        if hasattr(torch.cuda, 'empty_cache'):
                            torch.cuda.empty_cache()
                    else:
                        raise exception

                pred_remain = pred.detach()

                D_out = interp(model_D(F.softmax(pred, dim=1)))
                # FIX: F.sigmoid was removed from torch.nn.functional;
                # torch.sigmoid is the supported replacement.
                D_out_sigmoid = torch.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

                # FIX: np.bool was removed in NumPy 1.24 -- use the builtin bool.
                ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(bool)

                loss_semi_adv = args.lambda_semi_adv * bce_loss(
                    D_out, make_D_label(gt_label, ignore_mask_remain))
                loss_semi_adv = loss_semi_adv / args.iter_size
                #loss_semi_adv.backward()
                loss_semi_adv_value += loss_semi_adv.data.cpu().numpy() / args.lambda_semi_adv

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # produce ignore mask from low-confidence D outputs
                    semi_ignore_mask = (D_out_sigmoid < args.mask_T)
                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    semi_gt[semi_ignore_mask] = 255

                    semi_ratio = 1.0 - float(semi_ignore_mask.sum()) / semi_ignore_mask.size
                    print('semi ratio: {:.4f}'.format(semi_ratio))

                    if semi_ratio == 0.0:
                        loss_semi_value += 0
                    else:
                        semi_gt = torch.FloatTensor(semi_gt)
                        loss_semi = args.lambda_semi * loss_calc(pred, semi_gt, args.gpu)
                        loss_semi = loss_semi / args.iter_size
                        loss_semi_value += loss_semi.data.cpu().numpy() / args.lambda_semi
                        loss_semi += loss_semi_adv
                        loss_semi.backward()
            else:
                loss_semi = None
                loss_semi_adv = None

            # train with source
            try:
                _, batch = trainloader_iter.__next__()
            except:
                trainloader_iter = enumerate(trainloader)
                _, batch = trainloader_iter.__next__()

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            ignore_mask = (labels.numpy() == 255)

            try:
                pred = interp(model(images))
            except RuntimeError as exception:
                if "out of memory" in str(exception):
                    print("WARNING: out of memory")
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    raise exception

            loss_seg = loss_calc(pred, labels, args.gpu)
            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, ignore_mask))
            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy() / args.iter_size

            # train D: bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()
            if args.D_remain:
                # NOTE(review): pred_remain / ignore_mask_remain only exist
                # when the semi phase ran this iteration -- confirm D_remain
                # is always used together with the semi losses.
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain), axis=0)

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train with gt: get gt labels
            try:
                _, batch = trainloader_gt_iter.__next__()
            except:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = trainloader_gt_iter.__next__()

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'CITY_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'CITY_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'CITY_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'CITY_' + str(i_iter) + '_D.pth'))

    #torch.cuda.empty_cache()
    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Create the model and start the training.

    Spacenet variant: trains DeeplabMulti on the `config.dataset` city and
    adapts to `config.target` with two FCDiscriminators, one per prediction
    head.  Configuration comes from the module-level `args` and `config`.
    """
    # Crop sizes arrive as "W,H" strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    gpu = args.gpu

    # Create network.  The pretrained-weight restore is commented out, so the
    # model starts from DeeplabMulti's own initialization.
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        # if args.restore_from[:4] == 'http' :
        #     saved_state_dict = model_zoo.load_url(args.restore_from)
        # else:
        #     saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        # for i in saved_state_dict:
        #     # Scale.layer5.conv2d_list.3.weight
        #     i_parts = i.split('.')
        #     # print i_parts
        #     if not args.num_classes == 19 or not i_parts[1] == 'layer5':
        #         new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        #         # print i_parts
        # NOTE(review): with the restore commented out this loads the model's
        # own state_dict back into itself, i.e. it is a no-op.
        model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D -- one discriminator per DeeplabMulti prediction head
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # (original GTA5 source loader kept for reference)
    # trainloader = data.DataLoader(
    #     GTA5DataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.iter_size * args.batch_size,
    #                 crop_size=input_size,
    #                 scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN),
    #     batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
    train_set = spacenet.Spacenet(city=config.dataset, split='train', img_root=config.img_root)
    trainloader = DataLoader(train_set,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             drop_last=True)
    trainloader_iter = enumerate(trainloader)

    # (original Cityscapes target loader kept for reference)
    # targetloader = data.DataLoader(cityscapesDataSet(args.data_dir_target, args.data_list_target,
    #                                                  max_iters=args.num_steps * args.iter_size * args.batch_size,
    #                                                  crop_size=input_size_target,
    #                                                  scale=False, mirror=args.random_mirror, mean=IMG_MEAN,
    #                                                  set=args.set),
    #                                batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
    #                                pin_memory=True)
    target_set = spacenet.Spacenet(city=config.target, split='train', img_root=config.img_root)
    targetloader = DataLoader(target_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              drop_last=True)
    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # GAN objective: vanilla (BCE) or least-squares (MSE)
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G: don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False
            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            try:
                _, batch = trainloader_iter.__next__()
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = trainloader_iter.__next__()

            images, labels = batch
            # print(images)
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            # head 2 is the main head; head 1 is weighted by lambda_seg
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            # loss_seg_value1 += loss_seg1.data.cpu().numpy()[0] / args.iter_size
            # loss_seg_value2 += loss_seg2.data.cpu().numpy()[0] / args.iter_size
            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # train with target
            try:
                _, batch = targetloader_iter.__next__()
            except StopIteration:
                targetloader_iter = enumerate(targetloader)
                _, batch = targetloader_iter.__next__()

            images, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # fool the discriminators: target outputs labeled as source
            D_out1 = model_D1(F.softmax(pred_target1))
            D_out2 = model_D2(F.softmax(pred_target2))

            loss_adv_target1 = bce_loss(D_out1,
                                        Variable(torch.FloatTensor(D_out1.data.size()).fill_(source_label)).cuda(
                                            args.gpu))
            loss_adv_target2 = bce_loss(D_out2,
                                        Variable(torch.FloatTensor(D_out2.data.size()).fill_(source_label)).cuda(
                                            args.gpu))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            # loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy()[0] / args.iter_size
            # loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy()[0] / args.iter_size
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy() / args.iter_size

            # train D: bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detached so G receives no gradients)
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1))
            D_out2 = model_D2(F.softmax(pred2))

            loss_D1 = bce_loss(D_out1,
                               Variable(torch.FloatTensor(D_out1.data.size()).fill_(source_label)).cuda(args.gpu))
            loss_D2 = bce_loss(D_out2,
                               Variable(torch.FloatTensor(D_out2.data.size()).fill_(source_label)).cuda(args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1))
            D_out2 = model_D2(F.softmax(pred_target2))

            loss_D1 = bce_loss(D_out1,
                               Variable(torch.FloatTensor(D_out1.data.size()).fill_(target_label)).cuda(args.gpu))
            loss_D2 = bce_loss(D_out2,
                               Variable(torch.FloatTensor(D_out2.data.size()).fill_(target_label)).cuda(args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'.format(
                i_iter, args.num_steps, loss_seg_value1, loss_seg_value2, loss_adv_target_value1,
                loss_adv_target_value2, loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, 'paris_' + str(args.num_steps_stop) + '.pth'))
            torch.save(model_D1.state_dict(),
                       osp.join(args.snapshot_dir, 'paris_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(model_D2.state_dict(),
                       osp.join(args.snapshot_dir, 'paris_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, 'paris_' + str(i_iter) + '.pth'))
            torch.save(model_D1.state_dict(),
                       osp.join(args.snapshot_dir, 'paris_' + str(i_iter) + '_D1.pth'))
            torch.save(model_D2.state_dict(),
                       osp.join(args.snapshot_dir, 'paris_' + str(i_iter) + '_D2.pth'))
def main(): """Create the model and start the training.""" w, h = map(int, args.input_size.split(',')) input_size = (w, h) w, h = map(int, args.input_size_target.split(',')) input_size_target = (w, h) cudnn.enabled = True gpu = args.gpu tau = torch.ones(1) * args.tau tau = tau.cuda(args.gpu) # Create network if args.model == 'DeepLab': model = DeeplabMulti(num_classes=args.num_classes) if args.restore_from[:4] == 'http': saved_state_dict = model_zoo.load_url(args.restore_from) else: saved_state_dict = torch.load(args.restore_from) new_params = model.state_dict().copy() for i in saved_state_dict: # Scale.layer5.conv2d_list.3.weight i_parts = i.split('.') # print i_parts if not args.num_classes == 19 or not i_parts[1] == 'layer5': new_params['.'.join(i_parts[1:])] = saved_state_dict[i] # print i_parts model.load_state_dict(new_params, False) elif args.model == 'DeepLabVGG': model = DeeplabVGG(pretrained=True, num_classes=args.num_classes) model.train() model.cuda(args.gpu) cudnn.benchmark = True # init D model_D1 = FCDiscriminator(num_classes=args.num_classes) model_D2 = FCDiscriminator(num_classes=args.num_classes) model_D1.train() model_D1.cuda(args.gpu) model_D2.train() model_D2.cuda(args.gpu) if not os.path.exists(args.snapshot_dir): os.makedirs(args.snapshot_dir) mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] weak_transform = transforms.Compose([ # transforms.RandomCrop(32, 4), # transforms.RandomRotation(30), # transforms.Resize(1024), transforms.ToTensor(), # transforms.Normalize(mean, std), # RandomCrop(768) ]) target_transform = transforms.Compose([ # transforms.RandomCrop(32, 4), # transforms.RandomRotation(30), # transforms.Normalize(mean, std) # transforms.Resize(1024), # transforms.ToTensor(), # RandomCrop(768) ]) label_set = GTA5( root=args.data_dir, num_cls=19, split='all', remap_labels=True, transform=weak_transform, target_transform=target_transform, scale=input_size, # crop_transform=RandomCrop(int(768*(args.scale/1024))), ) unlabel_set = 
Cityscapes( root=args.data_dir_target, split=args.set, remap_labels=True, transform=weak_transform, target_transform=target_transform, scale=input_size_target, # crop_transform=RandomCrop(int(768*(args.scale/1024))), ) test_set = Cityscapes( root=args.data_dir_target, split='val', remap_labels=True, transform=weak_transform, target_transform=target_transform, scale=input_size_target, # crop_transform=RandomCrop(768) ) label_loader = data.DataLoader(label_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=False) unlabel_loader = data.DataLoader(unlabel_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=False) test_loader = data.DataLoader(test_set, batch_size=2, shuffle=False, num_workers=args.num_workers, pin_memory=False) # implement model.optim_parameters(args) to handle different models' lr setting optimizer = optim.SGD(model.optim_parameters(args), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) [model, model_D2, model_D2], [optimizer, optimizer_D1, optimizer_D2 ] = amp.initialize([model, model_D2, model_D2], [optimizer, optimizer_D1, optimizer_D2], opt_level="O1", num_losses=7) optimizer.zero_grad() optimizer_D1.zero_grad() optimizer_D2.zero_grad() if args.gan == 'Vanilla': bce_loss = torch.nn.BCEWithLogitsLoss() elif args.gan == 'LS': bce_loss = torch.nn.MSELoss() interp = Interpolate(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True) interp_target = Interpolate(size=(input_size_target[1], input_size_target[0]), mode='bilinear', align_corners=True) interp_test = Interpolate(size=(input_size_target[1], input_size_target[0]), mode='bilinear', align_corners=True) # interp_test = Interpolate(size=(1024, 2048), mode='bilinear', align_corners=True) 
normalize_transform = transforms.Compose([ torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # labels for adversarial training source_label = 0 target_label = 1 max_mIoU = 0 total_loss_seg_value1 = [] total_loss_adv_target_value1 = [] total_loss_D_value1 = [] total_loss_con_value1 = [] total_loss_seg_value2 = [] total_loss_adv_target_value2 = [] total_loss_D_value2 = [] total_loss_con_value2 = [] hist = np.zeros((num_cls, num_cls)) # for i_iter in range(args.num_steps): for i_iter, (batch, batch_un) in enumerate( zip(roundrobin_infinite(label_loader), roundrobin_infinite(unlabel_loader))): loss_seg_value1 = 0 loss_adv_target_value1 = 0 loss_D_value1 = 0 loss_con_value1 = 0 loss_seg_value2 = 0 loss_adv_target_value2 = 0 loss_D_value2 = 0 loss_con_value2 = 0 optimizer.zero_grad() adjust_learning_rate(optimizer, i_iter) optimizer_D1.zero_grad() optimizer_D2.zero_grad() adjust_learning_rate_D(optimizer_D1, i_iter) adjust_learning_rate_D(optimizer_D2, i_iter) # train G # don't accumulate grads in D for param in model_D1.parameters(): param.requires_grad = False for param in model_D2.parameters(): param.requires_grad = False # train with source images, labels = batch images_orig = images images = transform_batch(images, normalize_transform) images = Variable(images).cuda(args.gpu) pred1, pred2 = model(images) pred1 = interp(pred1) pred2 = interp(pred2) loss_seg1 = loss_calc(pred1, labels, args.gpu) loss_seg2 = loss_calc(pred2, labels, args.gpu) loss = loss_seg2 + args.lambda_seg * loss_seg1 # proper normalization loss = loss / args.iter_size with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss: scaled_loss.backward() # loss.backward() loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size # train with target images_tar, labels_tar = batch_un images_tar_orig = images_tar images_tar = transform_batch(images_tar, normalize_transform) images_tar = 
Variable(images_tar).cuda(args.gpu) pred_target1, pred_target2 = model(images_tar) pred_target1 = interp_target(pred_target1) pred_target2 = interp_target(pred_target2) D_out1 = model_D1(F.softmax(pred_target1, dim=1)) D_out2 = model_D2(F.softmax(pred_target2, dim=1)) loss_adv_target1 = bce_loss( D_out1, Variable( torch.FloatTensor( D_out1.data.size()).fill_(source_label)).cuda(args.gpu)) loss_adv_target2 = bce_loss( D_out2, Variable( torch.FloatTensor( D_out2.data.size()).fill_(source_label)).cuda(args.gpu)) loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2 loss = loss / args.iter_size with amp.scale_loss(loss, optimizer, loss_id=1) as scaled_loss: scaled_loss.backward() # loss.backward() loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy( ) / args.iter_size loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy( ) / args.iter_size # train with consistency loss # unsupervise phase policies = RandAugment().get_batch_policy(args.batch_size) rand_p1 = np.random.random(size=args.batch_size) rand_p2 = np.random.random(size=args.batch_size) random_dir = np.random.choice([-1, 1], size=[args.batch_size, 2]) images_aug = aug_batch_tensor(images_tar_orig, policies, rand_p1, rand_p2, random_dir) images_aug_orig = images_aug images_aug = transform_batch(images_aug, normalize_transform) images_aug = Variable(images_aug).cuda(args.gpu) pred_target_aug1, pred_target_aug2 = model(images_aug) pred_target_aug1 = interp_target(pred_target_aug1) pred_target_aug2 = interp_target(pred_target_aug2) pred_target1 = pred_target1.detach() pred_target2 = pred_target2.detach() max_pred1, psuedo_label1 = torch.max(F.softmax(pred_target1, dim=1), 1) max_pred2, psuedo_label2 = torch.max(F.softmax(pred_target2, dim=1), 1) psuedo_label1 = psuedo_label1.cpu().numpy().astype(np.float32) psuedo_label1_thre = psuedo_label1.copy() psuedo_label1_thre[(max_pred1 < tau).cpu().numpy().astype( np.bool)] = 255 # threshold to don't care 
psuedo_label1_thre = aug_batch_numpy(psuedo_label1_thre, policies, rand_p1, rand_p2, random_dir) psuedo_label2 = psuedo_label2.cpu().numpy().astype(np.float32) psuedo_label2_thre = psuedo_label2.copy() psuedo_label2_thre[(max_pred2 < tau).cpu().numpy().astype( np.bool)] = 255 # threshold to don't care psuedo_label2_thre = aug_batch_numpy(psuedo_label2_thre, policies, rand_p1, rand_p2, random_dir) psuedo_label1_thre = Variable(psuedo_label1_thre).cuda(args.gpu) psuedo_label2_thre = Variable(psuedo_label2_thre).cuda(args.gpu) if (psuedo_label1_thre != 255).sum().cpu().numpy() > 0: # nll_loss doesn't support empty tensors loss_con1 = loss_calc(pred_target_aug1, psuedo_label1_thre, args.gpu) loss_con_value1 += loss_con1.data.cpu().numpy() / args.iter_size else: loss_con1 = torch.tensor(0.0, requires_grad=True).cuda(args.gpu) if (psuedo_label2_thre != 255).sum().cpu().numpy() > 0: # nll_loss doesn't support empty tensors loss_con2 = loss_calc(pred_target_aug2, psuedo_label2_thre, args.gpu) loss_con_value2 += loss_con2.data.cpu().numpy() / args.iter_size else: loss_con2 = torch.tensor(0.0, requires_grad=True).cuda(args.gpu) loss = args.lambda_con * loss_con1 + args.lambda_con * loss_con2 # proper normalization loss = loss / args.iter_size with amp.scale_loss(loss, optimizer, loss_id=2) as scaled_loss: scaled_loss.backward() # loss.backward() # train D # bring back requires_grad for param in model_D1.parameters(): param.requires_grad = True for param in model_D2.parameters(): param.requires_grad = True # train with source pred1 = pred1.detach() pred2 = pred2.detach() D_out1 = model_D1(F.softmax(pred1, dim=1)) D_out2 = model_D2(F.softmax(pred2, dim=1)) loss_D1 = bce_loss( D_out1, Variable( torch.FloatTensor( D_out1.data.size()).fill_(source_label)).cuda(args.gpu)) loss_D2 = bce_loss( D_out2, Variable( torch.FloatTensor( D_out2.data.size()).fill_(source_label)).cuda(args.gpu)) loss_D1 = loss_D1 / args.iter_size / 2 loss_D2 = loss_D2 / args.iter_size / 2 with 
amp.scale_loss(loss_D1, optimizer_D1, loss_id=3) as scaled_loss: scaled_loss.backward() # loss_D1.backward() with amp.scale_loss(loss_D2, optimizer_D2, loss_id=4) as scaled_loss: scaled_loss.backward() # loss_D2.backward() loss_D_value1 += loss_D1.data.cpu().numpy() loss_D_value2 += loss_D2.data.cpu().numpy() # train with target pred_target1 = pred_target1.detach() pred_target2 = pred_target2.detach() D_out1 = model_D1(F.softmax(pred_target1, dim=1)) D_out2 = model_D2(F.softmax(pred_target2, dim=1)) loss_D1 = bce_loss( D_out1, Variable( torch.FloatTensor( D_out1.data.size()).fill_(target_label)).cuda(args.gpu)) loss_D2 = bce_loss( D_out2, Variable( torch.FloatTensor( D_out2.data.size()).fill_(target_label)).cuda(args.gpu)) loss_D1 = loss_D1 / args.iter_size / 2 loss_D2 = loss_D2 / args.iter_size / 2 with amp.scale_loss(loss_D1, optimizer_D1, loss_id=5) as scaled_loss: scaled_loss.backward() # loss_D1.backward() with amp.scale_loss(loss_D2, optimizer_D2, loss_id=6) as scaled_loss: scaled_loss.backward() # loss_D2.backward() loss_D_value1 += loss_D1.data.cpu().numpy() loss_D_value2 += loss_D2.data.cpu().numpy() optimizer.step() optimizer_D1.step() optimizer_D2.step() print('exp = {}'.format(args.snapshot_dir)) print( 'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}, loss_con1 = {8:.3f}, loss_con2 = {9:.3f}' .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2, loss_adv_target_value1, loss_adv_target_value2, loss_D_value1, loss_D_value2, loss_con_value1, loss_con_value2)) total_loss_seg_value1.append(loss_seg_value1) total_loss_adv_target_value1.append(loss_adv_target_value1) total_loss_D_value1.append(loss_D_value1) total_loss_con_value1.append(loss_con_value1) total_loss_seg_value2.append(loss_seg_value2) total_loss_adv_target_value2.append(loss_adv_target_value2) total_loss_D_value2.append(loss_D_value2) total_loss_con_value2.append(loss_con_value2) hist += 
fast_hist( labels.cpu().numpy().flatten().astype(int), torch.argmax(pred2, dim=1).cpu().numpy().flatten().astype(int), num_cls) if i_iter % 10 == 0: print('({}/{})'.format(i_iter + 1, int(args.num_steps))) acc_overall, acc_percls, iu, fwIU = result_stats(hist) mIoU = np.mean(iu) per_class = [[classes[i], acc] for i, acc in list(enumerate(iu))] per_class = np.array(per_class).flatten() print( ('per cls IoU :' + ('\n{:>14s} : {}') * 19).format(*per_class)) print('mIoU : {:0.2f}'.format(np.mean(iu))) print('fwIoU : {:0.2f}'.format(fwIU)) print('pixel acc : {:0.2f}'.format(acc_overall)) per_class = [[classes[i], acc] for i, acc in list(enumerate(acc_percls))] per_class = np.array(per_class).flatten() print( ('per cls acc :' + ('\n{:>14s} : {}') * 19).format(*per_class)) avg_train_acc = acc_overall avg_train_loss_seg1 = np.mean(total_loss_seg_value1) avg_train_loss_adv1 = np.mean(total_loss_adv_target_value1) avg_train_loss_dis1 = np.mean(total_loss_D_value1) avg_train_loss_con1 = np.mean(total_loss_con_value1) avg_train_loss_seg2 = np.mean(total_loss_seg_value2) avg_train_loss_adv2 = np.mean(total_loss_adv_target_value2) avg_train_loss_dis2 = np.mean(total_loss_D_value2) avg_train_loss_con2 = np.mean(total_loss_con_value2) print('avg_train_acc :', avg_train_acc) print('avg_train_loss_seg1 :', avg_train_loss_seg1) print('avg_train_loss_adv1 :', avg_train_loss_adv1) print('avg_train_loss_dis1 :', avg_train_loss_dis1) print('avg_train_loss_con1 :', avg_train_loss_con1) print('avg_train_loss_seg2 :', avg_train_loss_seg2) print('avg_train_loss_adv2 :', avg_train_loss_adv2) print('avg_train_loss_dis2 :', avg_train_loss_dis2) print('avg_train_loss_con2 :', avg_train_loss_con2) writer['train'].add_scalar('log/mIoU', mIoU, i_iter) writer['train'].add_scalar('log/acc', avg_train_acc, i_iter) writer['train'].add_scalar('log1/loss_seg', avg_train_loss_seg1, i_iter) writer['train'].add_scalar('log1/loss_adv', avg_train_loss_adv1, i_iter) writer['train'].add_scalar('log1/loss_dis', 
avg_train_loss_dis1, i_iter) writer['train'].add_scalar('log1/loss_con', avg_train_loss_con1, i_iter) writer['train'].add_scalar('log2/loss_seg', avg_train_loss_seg2, i_iter) writer['train'].add_scalar('log2/loss_adv', avg_train_loss_adv2, i_iter) writer['train'].add_scalar('log2/loss_dis', avg_train_loss_dis2, i_iter) writer['train'].add_scalar('log2/loss_con', avg_train_loss_con2, i_iter) hist = np.zeros((num_cls, num_cls)) total_loss_seg_value1 = [] total_loss_adv_target_value1 = [] total_loss_D_value1 = [] total_loss_con_value1 = [] total_loss_seg_value2 = [] total_loss_adv_target_value2 = [] total_loss_D_value2 = [] total_loss_con_value2 = [] fig = plt.figure(figsize=(15, 15)) labels = labels[0].cpu().numpy().astype(np.float32) ax = fig.add_subplot(331) ax.imshow(print_palette(Image.fromarray(labels).convert('L'))) ax.axis("off") ax.set_title('labels') ax = fig.add_subplot(337) images = images_orig[0].cpu().numpy().transpose((1, 2, 0)) # images += IMG_MEAN ax.imshow(images) ax.axis("off") ax.set_title('datas') _, pred2 = torch.max(pred2, dim=1) pred2 = pred2[0].cpu().numpy().astype(np.float32) ax = fig.add_subplot(334) ax.imshow(print_palette(Image.fromarray(pred2).convert('L'))) ax.axis("off") ax.set_title('predicts') labels_tar = labels_tar[0].cpu().numpy().astype(np.float32) ax = fig.add_subplot(332) ax.imshow(print_palette(Image.fromarray(labels_tar).convert('L'))) ax.axis("off") ax.set_title('tar_labels') ax = fig.add_subplot(338) ax.imshow(images_tar_orig[0].cpu().numpy().transpose((1, 2, 0))) ax.axis("off") ax.set_title('tar_datas') _, pred_target2 = torch.max(pred_target2, dim=1) pred_target2 = pred_target2[0].cpu().numpy().astype(np.float32) ax = fig.add_subplot(335) ax.imshow(print_palette( Image.fromarray(pred_target2).convert('L'))) ax.axis("off") ax.set_title('tar_predicts') print(policies[0], 'p1', rand_p1[0], 'p2', rand_p2[0], 'random_dir', random_dir[0]) psuedo_label2_thre = psuedo_label2_thre[0].cpu().numpy().astype( np.float32) ax = 
fig.add_subplot(333) ax.imshow( print_palette( Image.fromarray(psuedo_label2_thre).convert('L'))) ax.axis("off") ax.set_title('psuedo_labels') ax = fig.add_subplot(339) ax.imshow(images_aug_orig[0].cpu().numpy().transpose((1, 2, 0))) ax.axis("off") ax.set_title('aug_datas') _, pred_target_aug2 = torch.max(pred_target_aug2, dim=1) pred_target_aug2 = pred_target_aug2[0].cpu().numpy().astype( np.float32) ax = fig.add_subplot(336) ax.imshow( print_palette(Image.fromarray(pred_target_aug2).convert('L'))) ax.axis("off") ax.set_title('aug_predicts') # plt.show() writer['train'].add_figure('image/', fig, global_step=i_iter, close=True) if i_iter % 500 == 0: loss1 = [] loss2 = [] for test_i, batch in enumerate(test_loader): images, labels = batch images_orig = images images = transform_batch(images, normalize_transform) images = Variable(images).cuda(args.gpu) pred1, pred2 = model(images) pred1 = interp_test(pred1) pred1 = pred1.detach() pred2 = interp_test(pred2) pred2 = pred2.detach() loss_seg1 = loss_calc(pred1, labels, args.gpu) loss_seg2 = loss_calc(pred2, labels, args.gpu) loss1.append(loss_seg1.item()) loss2.append(loss_seg2.item()) hist += fast_hist( labels.cpu().numpy().flatten().astype(int), torch.argmax(pred2, dim=1).cpu().numpy().flatten().astype(int), num_cls) print('test') fig = plt.figure(figsize=(15, 15)) labels = labels[-1].cpu().numpy().astype(np.float32) ax = fig.add_subplot(311) ax.imshow(print_palette(Image.fromarray(labels).convert('L'))) ax.axis("off") ax.set_title('labels') ax = fig.add_subplot(313) ax.imshow(images_orig[-1].cpu().numpy().transpose((1, 2, 0))) ax.axis("off") ax.set_title('datas') _, pred2 = torch.max(pred2, dim=1) pred2 = pred2[-1].cpu().numpy().astype(np.float32) ax = fig.add_subplot(312) ax.imshow(print_palette(Image.fromarray(pred2).convert('L'))) ax.axis("off") ax.set_title('predicts') # plt.show() writer['test'].add_figure('test_image/', fig, global_step=i_iter, close=True) acc_overall, acc_percls, iu, fwIU = result_stats(hist) 
mIoU = np.mean(iu) per_class = [[classes[i], acc] for i, acc in list(enumerate(iu))] per_class = np.array(per_class).flatten() print( ('per cls IoU :' + ('\n{:>14s} : {}') * 19).format(*per_class)) print('mIoU : {:0.2f}'.format(mIoU)) print('fwIoU : {:0.2f}'.format(fwIU)) print('pixel acc : {:0.2f}'.format(acc_overall)) per_class = [[classes[i], acc] for i, acc in list(enumerate(acc_percls))] per_class = np.array(per_class).flatten() print( ('per cls acc :' + ('\n{:>14s} : {}') * 19).format(*per_class)) avg_test_loss1 = np.mean(loss1) avg_test_loss2 = np.mean(loss2) avg_test_acc = acc_overall print('avg_test_loss2 :', avg_test_loss1) print('avg_test_loss1 :', avg_test_loss2) print('avg_test_acc :', avg_test_acc) writer['test'].add_scalar('log1/loss_seg', avg_test_loss1, i_iter) writer['test'].add_scalar('log2/loss_seg', avg_test_loss2, i_iter) writer['test'].add_scalar('log/acc', avg_test_acc, i_iter) writer['test'].add_scalar('log/mIoU', mIoU, i_iter) hist = np.zeros((num_cls, num_cls)) if i_iter >= args.num_steps_stop - 1: print('save model ...') torch.save( model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth')) torch.save( model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D1.pth')) torch.save( model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D2.pth')) break if max_mIoU < mIoU: max_mIoU = mIoU torch.save( model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '.pth')) torch.save( model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '_D1.pth')) torch.save( model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '_D2.pth'))
def main():
    """Create the model and start adversarial domain-adaptation training.

    Trains a two-head DeeplabMulti segmenter on a labeled source set
    (SynthiaDataSet) plus a target set (cityscapesDataSet), with two
    discriminators:
      * model_D1 scores *pairs* of softmax maps (channel-concatenated, so its
        input has 2 x num_classes channels) as real / fake / mixed,
      * model_D2 scores single softmax maps as source vs. target.
    The segmenter lives on GPU ``gpu_id_2`` and the discriminators (and all
    losses) on GPU ``gpu_id_1``.  Relies on module-level ``args``,
    ``cut_size``, ``IMG_MEAN``, ``loss_calc`` and the lr-schedule helpers.
    """
    # Discriminators/losses on GPU 0, segmentation model on GPU 1.
    gpu_id_2 = 1
    gpu_id_1 = 0
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)
    cudnn.enabled = True
    # gpu = args.gpu

    # Create network and load pretrained weights (layer5/fc excluded).
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            print("from url")
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            print("from restore")
            # NOTE(review): hard-coded checkpoint path ignores
            # args.restore_from in the non-URL branch — confirm intended.
            saved_state_dict = torch.load(
                '/data2/zhangjunyi/snapshots/snapshots_syn/onlysyn/GTA5_80000.pth',
                map_location={"cuda:3": "cuda:0"})
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # Skip the classifier head and fc weights; copy everything else.
            if (not i_parts[1] == 'layer5') and (not i_parts[0] == 'fc'):
                new_params['.'.join(i_parts)] = saved_state_dict[i]
        model.load_state_dict(new_params)
    model.train()
    model.cuda(gpu_id_2)
    cudnn.benchmark = True

    # init D: D1 consumes concatenated prediction pairs, D2 single maps.
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes, match=False)
    model_D1.train()
    model_D1.cuda(gpu_id_1)
    model_D2.train()
    model_D2.cuda(gpu_id_1)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(SynthiaDataSet(
        args.data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)
    trainloader_iter = enumerate(trainloader)
    # Seed the "previous batch" used for the paired-discriminator inputs.
    _, batch_last = trainloader_iter.__next__()

    # NOTE(review): ``cut_size`` is not defined in this function — presumably
    # a module-level constant; verify it exists at call time.
    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        cut_size=cut_size,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)
    targetloader_iter = enumerate(targetloader)
    _, batch_last_target = targetloader_iter.__next__()

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()
    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    mse_loss = torch.nn.MSELoss()

    def upsample_(input_):
        # Upsample logits to the source crop size (size is (H, W)).
        return nn.functional.interpolate(input_,
                                         size=(input_size[1], input_size[0]),
                                         mode='bilinear',
                                         align_corners=False)

    def upsample_target(input_):
        # Upsample logits to the target crop size.
        return nn.functional.interpolate(input_,
                                         size=(input_size_target[1],
                                               input_size_target[0]),
                                         mode='bilinear',
                                         align_corners=False)

    interp = upsample_
    interp_target = upsample_target

    # Labels for adversarial training (LSGAN-style targets for mse_loss).
    source_label = 1
    target_label = -1
    mix_label = 0

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        number1 = 0
        number2 = 0
        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G
            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False
            for param in model_D2.parameters():
                param.requires_grad = False

            def result_model(batch, interp_):
                # Forward one batch: images on the model GPU, both upsampled
                # prediction heads and the labels moved to the loss GPU.
                images, labels, _, name = batch
                images = Variable(images).cuda(gpu_id_2)
                labels = Variable(labels.long()).cuda(gpu_id_1)
                pred1, pred2 = model(images)
                pred1 = interp_(pred1)
                pred2 = interp_(pred2)
                pred1_ = pred1.cuda(gpu_id_1)
                pred2_ = pred2.cuda(gpu_id_1)
                return pred1_, pred2_, labels

            beta = args.beta
            if i_iter == 0:
                print(beta)
            _, batch = trainloader_iter.__next__()
            _, batch_target = targetloader_iter.__next__()

            # Segmentation loss on source; loss_calc also returns relabeled
            # targets (beta-based filtering; 255 = ignore index).
            pred1, pred2, labels = result_model(batch, interp)
            loss_seg1, new_labels = loss_calc(pred1, labels, gpu_id_1, beta)
            labels = new_labels
            number1 = torch.sum(labels == 255).item()  # count of ignored pixels
            loss_seg2, new_labels = loss_calc(pred2, labels, gpu_id_1, beta)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            loss = loss / args.iter_size
            loss.backward()
            # NOTE(review): loss_seg_1 / loss_seg_2 are computed but unused.
            loss_seg_1 = loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_2 = loss_seg2.data.cpu().numpy() / args.iter_size

            # Segmentation loss on target (labels presumably pseudo/partial).
            pred1, pred2, labels = result_model(batch_target, interp_target)
            loss_seg1, new_labels = loss_calc(pred1, labels, gpu_id_1, beta)
            labels = new_labels
            number2 = torch.sum(labels == 255).item()
            # NOTE(review): ``new_lables`` is a typo and is never used.
            loss_seg2, new_lables = loss_calc(pred2, labels, gpu_id_1, beta)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # Adversarial G step vs D1: pair (current target, previous target)
            # should fool D1 into predicting "source" (real).
            pred1_last_target, pred2_last_target, labels_last_target = result_model(
                batch_last_target, interp_target)
            pred1_target, pred2_target, labels_target = result_model(
                batch_target, interp_target)
            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)
            pred1_last_target_D = F.softmax((pred1_last_target), dim=1)
            pred2_last_target_D = F.softmax((pred2_last_target), dim=1)
            fake1_D = torch.cat((pred1_target_D, pred1_last_target_D), dim=1)
            fake2_D = torch.cat((pred2_target_D, pred2_last_target_D), dim=1)
            # Both heads go through model_D1 — its input is the concatenated
            # pair, so D2 (single-map discriminator) is not used here.
            D_out_fake_1 = model_D1(fake1_D)
            D_out_fake_2 = model_D1(fake2_D)
            loss_adv_fake1 = mse_loss(
                D_out_fake_1,
                Variable(
                    torch.FloatTensor(D_out_fake_1.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))
            loss_adv_fake2 = mse_loss(
                D_out_fake_2,
                Variable(
                    torch.FloatTensor(D_out_fake_2.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))
            loss_adv_target1 = loss_adv_fake1
            loss_adv_target2 = loss_adv_fake2
            loss = args.lambda_adv_target1 * loss_adv_target1.cuda(
                gpu_id_1) + args.lambda_adv_target2 * loss_adv_target2.cuda(
                    gpu_id_1)
            loss = loss / args.iter_size
            loss.backward()

            # Adversarial G step on mixed (target, source) pairs, weighted x2.
            pred1, pred2, labels = result_model(batch, interp)
            pred1_target, pred2_target, labels_target = result_model(
                batch_target, interp_target)
            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)
            pred1_D = F.softmax((pred1), dim=1)
            pred2_D = F.softmax((pred2), dim=1)
            mix1_D = torch.cat((pred1_target_D, pred1_D), dim=1)
            mix2_D = torch.cat((pred2_target_D, pred2_D), dim=1)
            D_out_mix_1 = model_D1(mix1_D)
            D_out_mix_2 = model_D1(mix2_D)
            loss_adv_mix1 = mse_loss(
                D_out_mix_1,
                Variable(
                    torch.FloatTensor(D_out_mix_1.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))
            loss_adv_mix2 = mse_loss(
                D_out_mix_2,
                Variable(
                    torch.FloatTensor(D_out_mix_2.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))
            loss_adv_target1 = loss_adv_mix1 * 2
            loss_adv_target2 = loss_adv_mix2 * 2
            loss = args.lambda_adv_target1 * loss_adv_target1.cuda(
                gpu_id_1) + args.lambda_adv_target2 * loss_adv_target2.cuda(
                    gpu_id_1)
            loss = loss / args.iter_size
            loss.backward()
            # NOTE(review): both adv terms accumulate into value1 here
            # (value2 is fed by the D2 branch below) — confirm intended.
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
            ) / args.iter_size
            loss_adv_target_value1 += loss_adv_target2.data.cpu().numpy(
            ) / args.iter_size

            # train D2 (G adversarial step vs single-map discriminator D2).
            pred1_target, pred2_target, labels_target = result_model(
                batch_target, interp_target)
            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)
            D_out_target_1 = model_D2(pred1_target_D)
            D_out_target_2 = model_D2(pred2_target_D)
            loss_adv_target1 = bce_loss(
                D_out_target_1,
                Variable(
                    torch.FloatTensor(D_out_target_1.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))
            loss_adv_target2 = bce_loss(
                D_out_target_2,
                Variable(
                    torch.FloatTensor(D_out_target_2.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))
            loss = args.lambda_adv_target1 * loss_adv_target1.cuda(
                gpu_id_1) + args.lambda_adv_target2 * loss_adv_target2.cuda(
                    gpu_id_1)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value2 += loss_adv_target1.data.cpu().numpy(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
            ) / args.iter_size

            # bring back requires_grad: now train the discriminators.
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True

            pred1_last, pred2_last, labels_last = result_model(
                batch_last, interp)

            # train with source: real = (source, last source) pairs,
            # mix = (last source, target) pairs; predictions are detached so
            # only D1 receives gradients.
            pred1 = pred1.detach().cuda(gpu_id_1)
            pred2 = pred2.detach().cuda(gpu_id_1)
            pred1_target = pred1_target.detach().cuda(gpu_id_1)
            pred2_target = pred2_target.detach().cuda(gpu_id_1)
            pred1_last = pred1_last.detach().cuda(gpu_id_1)
            pred2_last = pred2_last.detach().cuda(gpu_id_1)
            pred1_D = F.softmax((pred1), dim=1)
            pred2_D = F.softmax((pred2), dim=1)
            pred1_last_D = F.softmax((pred1_last), dim=1)
            pred2_last_D = F.softmax((pred2_last), dim=1)
            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)
            real1_D = torch.cat((pred1_D, pred1_last_D), dim=1)
            real2_D = torch.cat((pred2_D, pred2_last_D), dim=1)
            mix1_D_ = torch.cat((pred1_last_D, pred1_target_D), dim=1)
            mix2_D_ = torch.cat((pred2_last_D, pred2_target_D), dim=1)
            D_out1_real = model_D1(real1_D)
            D_out2_real = model_D1(real2_D)
            D_out1_mix = model_D1(mix1_D_)
            D_out2_mix = model_D1(mix2_D_)
            loss_D1 = mse_loss(
                D_out1_real,
                Variable(
                    torch.FloatTensor(D_out1_real.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))
            loss_D2 = mse_loss(
                D_out2_real,
                Variable(
                    torch.FloatTensor(D_out2_real.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))
            loss_D3 = mse_loss(
                D_out1_mix,
                Variable(
                    torch.FloatTensor(D_out1_mix.data.size()).fill_(
                        mix_label)).cuda(gpu_id_1))
            loss_D4 = mse_loss(
                D_out2_mix,
                Variable(
                    torch.FloatTensor(D_out2_mix.data.size()).fill_(
                        mix_label)).cuda(gpu_id_1))
            loss_D1 = (loss_D1 + loss_D3) / args.iter_size / 2
            loss_D2 = (loss_D2 + loss_D4) / args.iter_size / 2
            loss_D1.backward()
            loss_D2.backward()
            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value1 += loss_D2.data.cpu().numpy()

            # train with target: fake = (last target, target) pairs,
            # mix = (source, last target) pairs.
            pred1 = pred1.detach().cuda(gpu_id_1)
            pred2 = pred2.detach().cuda(gpu_id_1)
            pred1_target = pred1_target.detach().cuda(gpu_id_1)
            pred2_target = pred2_target.detach().cuda(gpu_id_1)
            pred1_last_target = pred1_last_target.detach().cuda(gpu_id_1)
            pred2_last_target = pred2_last_target.detach().cuda(gpu_id_1)
            pred1_D = F.softmax((pred1), dim=1)
            pred2_D = F.softmax((pred2), dim=1)
            pred1_last_target_D = F.softmax((pred1_last_target), dim=1)
            pred2_last_target_D = F.softmax((pred2_last_target), dim=1)
            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)
            fake1_D_ = torch.cat((pred1_last_target_D, pred1_target_D), dim=1)
            fake2_D_ = torch.cat((pred2_last_target_D, pred2_target_D), dim=1)
            mix1_D__ = torch.cat((pred1_D, pred1_last_target_D), dim=1)
            mix2_D__ = torch.cat((pred2_D, pred2_last_target_D), dim=1)
            D_out1 = model_D1(fake1_D_)
            D_out2 = model_D1(fake2_D_)
            D_out3 = model_D1(mix1_D__)
            D_out4 = model_D1(mix2_D__)
            loss_D1 = mse_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(D_out1.data.size()).fill_(
                        target_label)).cuda(gpu_id_1))
            loss_D2 = mse_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(D_out2.data.size()).fill_(
                        target_label)).cuda(gpu_id_1))
            loss_D3 = mse_loss(
                D_out3,
                Variable(
                    torch.FloatTensor(
                        D_out3.data.size()).fill_(mix_label)).cuda(gpu_id_1))
            loss_D4 = mse_loss(
                D_out4,
                Variable(
                    torch.FloatTensor(
                        D_out4.data.size()).fill_(mix_label)).cuda(gpu_id_1))
            loss_D1 = (loss_D1 + loss_D3) / args.iter_size / 2
            loss_D2 = (loss_D2 + loss_D4) / args.iter_size / 2
            loss_D1.backward()
            loss_D2.backward()
            # Remember this iteration's batches as next iteration's "last".
            batch_last, batch_last_target = batch, batch_target
            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value1 += loss_D2.data.cpu().numpy()

            # train model-D2: source maps labeled source, target maps mix.
            pred1 = pred1.detach().cuda(gpu_id_1)
            pred2 = pred2.detach().cuda(gpu_id_1)
            pred1_target = pred1_target.detach().cuda(gpu_id_1)
            pred2_target = pred2_target.detach().cuda(gpu_id_1)
            pred1_D = F.softmax((pred1), dim=1)
            pred2_D = F.softmax((pred2), dim=1)
            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)
            D_out1 = model_D2(pred1_D)
            D_out2 = model_D2(pred2_D)
            D_out3 = model_D2(pred1_target_D)
            D_out4 = model_D2(pred2_target_D)
            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(D_out1.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))
            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(D_out2.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))
            loss_D3 = bce_loss(
                D_out3,
                Variable(
                    torch.FloatTensor(
                        D_out3.data.size()).fill_(mix_label)).cuda(gpu_id_1))
            loss_D4 = bce_loss(
                D_out4,
                Variable(
                    torch.FloatTensor(
                        D_out4.data.size()).fill_(mix_label)).cuda(gpu_id_1))
            loss_D1 = (loss_D1 + loss_D3) / args.iter_size / 2
            loss_D2 = (loss_D2 + loss_D4) / args.iter_size / 2
            loss_D1.backward()
            loss_D2.backward()
            batch_last, batch_last_target = batch, batch_target
            loss_D_value2 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}, number1 = {8}, number2 = {9}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2, number1, number2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(i_iter) + '_D2.pth'))
def main(args):
    """Create the model and start semi-supervised adversarial training.

    Trains a DeeplabMulti segmenter on labeled source data, labeled target
    data, and unlabeled target data; a single output-space discriminator
    (model_D2) pushes unlabeled-target predictions toward the source
    distribution (AdaptSegNet-style).  Only the second prediction head is
    used (the pred1 path is deliberately disabled).

    Fix vs. original: ``F.softmax(..., dim=-1)`` normalized over the *width*
    axis of the NCHW logits; the discriminator must see per-pixel class
    probabilities, i.e. softmax over the channel axis ``dim=1`` (as every
    sibling training script in this file does).

    Relies on module-level IMG_MEAN, loss_calc, adjust_learning_rate,
    adjust_learning_rate_D, DeeplabMulti, FCDiscriminator, ListDataSet.
    """
    mkdir_check(args.snapshot_dir)
    writer = SummaryWriter(log_dir=os.path.join(args.snapshot_dir, 'tb-logs'))

    h, w = map(int, args.src_input_size.split(','))
    src_input_size = (h, w)
    h, w = map(int, args.tgt_input_size.split(','))
    tgt_input_size = (h, w)

    cudnn.enabled = True

    # Create network and load pretrained weights, skipping the classifier
    # head when it does not match (19-class case keeps everything but layer5).
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)
    model.train()
    model.cuda(args.gpu)
    cudnn.benchmark = True

    # init D: output-space discriminator over 19-class probability maps.
    model_D2 = FCDiscriminator(num_classes=19)
    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(
        ListDataSet(args.src_data_dir, args.src_img_list, args.src_lbl_list,
                    max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=src_input_size,
                    mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
        ListDataSet(args.tgt_data_dir, args.tgt_img_list, args.tgt_lbl_list,
                    max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=tgt_input_size,
                    mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)

    # Unlabeled target images drive the adversarial alignment only.
    targetloader_nolabel = data.DataLoader(
        ListDataSet(args.tgt_data_dir, args.tgt_img_nolabel_list, None,
                    max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=tgt_input_size,
                    mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    targetloader_nolabel_iter = enumerate(targetloader_nolabel)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(src_input_size[1], src_input_size[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(tgt_input_size[1], tgt_input_size[0]),
                                mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value2 = 0
        loss_tgt_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(args, optimizer, i_iter)
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(args, optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G
            # don't accumulate grads in D
            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source (supervised segmentation loss, head 2 only)
            _, batch = trainloader_iter.__next__()
            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred2 = interp(pred2)

            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = loss_seg2
            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # train with target seg (labeled target subset)
            _, batch = targetloader_iter.__next__()
            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target2 = interp_target(pred_target2)

            loss_tgt_seg2 = loss_calc(pred_target2, labels, args.gpu)
            loss = loss_tgt_seg2
            # proper normalization
            loss = loss / args.iter_size
            # The graph is not reused afterwards (pred_target* are reassigned
            # below and D training uses detached tensors), so retain_graph is
            # not needed and would only hold memory.
            loss.backward()
            loss_tgt_seg_value2 += loss_tgt_seg2.data.cpu().numpy() / args.iter_size

            # train with target_nolabel adv: fool D into labeling unlabeled
            # target predictions as "source".
            _, batch = targetloader_nolabel_iter.__next__()
            images, _, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target2 = interp_target(pred_target2)

            # softmax over the class/channel axis (dim=1 for NCHW logits).
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))
            loss_adv_target2 = bce_loss(
                D_out2,
                Variable(torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label)).cuda(args.gpu))
            loss = args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy() / args.iter_size

            # train D
            # bring back requires_grad
            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detach so G receives no gradient)
            pred2 = pred2.detach()
            D_out2 = model_D2(F.softmax(pred2, dim=1))
            loss_D2 = bce_loss(
                D_out2,
                Variable(torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label)).cuda(args.gpu))
            loss_D2 = loss_D2 / args.iter_size / 2
            loss_D2.backward()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target
            pred_target2 = pred_target2.detach()
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))
            loss_D2 = bce_loss(
                D_out2,
                Variable(torch.FloatTensor(
                    D_out2.data.size()).fill_(target_label)).cuda(args.gpu))
            loss_D2 = loss_D2 / args.iter_size / 2
            loss_D2.backward()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D2.step()

        print(
            'iter = {:5d}/{:8d}, loss_seg2 = {:.3f} loss_tgt_seg2 = {:.3f} loss_adv2 = {:.3f} loss_D2 = {:.3f}'.format(
                i_iter, args.num_steps_stop, loss_seg_value2,
                loss_tgt_seg_value2, loss_adv_target_value2, loss_D_value2))

        writer.add_scalars('loss/seg', {
            'src2': loss_seg_value2,
            'tgt2': loss_tgt_seg_value2,
        }, i_iter)
        writer.add_scalar('loss/adv', loss_adv_target_value2, i_iter)
        writer.add_scalar('loss/d', loss_D_value2, i_iter)

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, 'model_{}.pth'.format(i_iter)))
            torch.save(model_D2.state_dict(),
                       osp.join(args.snapshot_dir, 'model_d_{}.pth'.format(i_iter)))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, 'model_{}.pth'.format(i_iter)))
            torch.save(model_D2.state_dict(),
                       osp.join(args.snapshot_dir, 'model_d_{}.pth'.format(i_iter)))
def main():
    """Create the model and start the training.

    Adversarial domain adaptation GTA5 -> Cityscapes with two discriminators:
    model_D[0] (FCDiscriminator) on backbone features, model_D[1]
    (OutspaceDiscriminator) on the softmax output space.
    """
    setup_seed(666)  # fixed seed for reproducibility
    device = torch.device("cuda" if not args.cpu else "cpu")
    # Input sizes come in as "W,H" strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)
    cudnn.enabled = True
    # Create network
    model = DeeplabMulti(num_classes=args.num_classes)
    saved_state_dict = torch.load(args.restore_from)
    new_params = model.state_dict().copy()
    for i in saved_state_dict:
        # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight";
        # the leading scope name is dropped when remapping below.
        i_parts = i.split('.')
        # Skip the checkpoint classifier ('layer5') only when training the
        # standard 19-class setup; otherwise copy everything.
        if not args.num_classes == 19 or not i_parts[1] == 'layer5':
            if i_parts[1]=='layer4' and i_parts[2]=='2':
                # Remap the last residual block of layer4 onto this model's
                # 'layer5.0' — presumably DeeplabMulti splits the backbone
                # there; TODO confirm against the model definition.
                i_parts[1] = 'layer5'
                i_parts[2] = '0'
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
            else:
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
    model.load_state_dict(new_params)
    model.train()
    model.to(device)
    cudnn.benchmark = True
    # init D: index 0 takes 2048-channel features, index 1 takes the
    # 19-channel segmentation output.
    num_class_list = [2048, 19]
    model_D = nn.ModuleList([FCDiscriminator(num_classes=num_class_list[i]).train().to(device) if i<1 else OutspaceDiscriminator(num_classes=num_class_list[i]).train().to(device) for i in range(2)])
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.batch_size,
                    crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
    trainloader_iter = enumerate(trainloader)
    targetloader = data.DataLoader(cityscapesDataSet(args.data_dir_target, args.data_list_target,
                                                     max_iters=args.num_steps * args.batch_size,
                                                     crop_size=input_size_target,
                                                     scale=False, mirror=args.random_mirror, mean=IMG_MEAN,
                                                     set=args.set),
                                   batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)
    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()
    # NOTE(review): named bce_loss but this is MSE — the least-squares GAN
    # objective, not binary cross-entropy.
    bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)
    # Upsample predictions back to (H, W) of the original crops.
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear',
                                align_corners=True)
    # labels for adversarial training
    source_label = 0
    target_label = 1
    for i_iter in range(args.num_steps):
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)
        # train G
        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False
        # train with source: supervised segmentation loss only.
        _, batch = trainloader_iter.__next__()
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)
        # The model also receives the discriminators and a domain tag —
        # presumably for domain-conditioned behavior inside the network;
        # TODO confirm against DeeplabMulti.forward.
        feat_source, pred_source = model(images, model_D, 'source')
        pred_source = interp(pred_source)
        loss_seg = seg_loss(pred_source, labels)
        loss_seg.backward()
        # train with target: fool both discriminators into predicting
        # "source" on target features/outputs.
        _, batch = targetloader_iter.__next__()
        images, _, _ = batch
        images = images.to(device)
        feat_target, pred_target = model(images, model_D, 'target')
        pred_target = interp_target(pred_target)
        loss_adv = 0
        D_out = model_D[0](feat_target)
        loss_adv += bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).to(device))
        D_out = model_D[1](F.softmax(pred_target, dim=1))
        loss_adv += bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).to(device))
        # fixed adversarial weight
        loss_adv = loss_adv*0.01
        loss_adv.backward()
        optimizer.step()
        # train D
        # bring back requires_grad
        for param in model_D.parameters():
            param.requires_grad = True
        # train with source (real = source_label); detach so G gets no grads.
        loss_D_source = 0
        D_out_source = model_D[0](feat_source.detach())
        loss_D_source += bce_loss(D_out_source,
                                  torch.FloatTensor(D_out_source.data.size()).fill_(source_label).to(device))
        D_out_source = model_D[1](F.softmax(pred_source.detach(),dim=1))
        loss_D_source += bce_loss(D_out_source,
                                  torch.FloatTensor(D_out_source.data.size()).fill_(source_label).to(device))
        loss_D_source.backward()
        # train with target (fake = target_label).
        loss_D_target = 0
        D_out_target = model_D[0](feat_target.detach())
        loss_D_target += bce_loss(D_out_target,
                                  torch.FloatTensor(D_out_target.data.size()).fill_(target_label).to(device))
        D_out_target = model_D[1](F.softmax(pred_target.detach(),dim=1))
        loss_D_target += bce_loss(D_out_target,
                                  torch.FloatTensor(D_out_target.data.size()).fill_(target_label).to(device))
        loss_D_target.backward()
        optimizer_D.step()
        if i_iter % 10 == 0:
            print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f} loss_D_s = {4:.3f}, loss_D_t = {5:.3f}'.format(
                i_iter, args.num_steps, loss_seg.item(), loss_adv.item(), loss_D_source.item(), loss_D_target.item()))
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))
def main():
    """Create the model and start the training.

    CDAN-style adversarial adaptation: one discriminator over the
    (softmax(pred1), softmax(pred2)) pair combined by `CDAN(...,
    cdan_implement='concat')`. Supports restarting from a snapshot dir whose
    hyper-parameters were recorded in args_dict_{iter}.json.

    Fixes vs. previous revision:
    - discriminator LR schedule now uses adjust_learning_rate_D (was the
      generator schedule with the generator base LR),
    - args_dict_{}.json is saved under snapshot_dir with a path separator so
      the RESTART branch can read it back,
    - removed a stray pdb.set_trace() that froze unattended runs,
    - removed a dead `continue` at the end of the sub-iteration loop,
    - F.softmax calls pass dim=1 explicitly (implicit dim is deprecated).
    """
    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)
    args_dict = vars(args)
    import json
    ###### load args for restart ######
    if RESTART:
        # Restore the exact hyper-parameters recorded at RESTART_ITER.
        args_dict_file = args.snapshot_dir + '/args_dict_{}.json'.format(
            RESTART_ITER)
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        for arg in args_dict:
            args_dict[arg] = args_dict_last[arg]
    ###### load args for restart ######
    device = torch.device("cuda" if not args.cpu else "cpu")
    # Input sizes come in as "W,H" strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)
    cudnn.enabled = True
    cudnn.benchmark = True
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
    # CDAN discriminator sees both classifier outputs concatenated.
    model_D = FCDiscriminator(num_classes=2 * args.num_classes).to(device)
    #### restore model_D and model
    if RESTART:
        # model parameters
        restart_from_model = args.restart_from + 'GTA5_{}.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_model)
        model.load_state_dict(saved_state_dict)
        # model_D parameters
        restart_from_D = args.restart_from + 'GTA5_{}_D.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_D)
        model_D.load_state_dict(saved_state_dict)
    #### model_D is randomly initialized, model is pre-trained ResNet on ImageNet
    else:
        # model parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Keys look like "Scale.layer5.conv2d_list.3.weight"; drop the
            # leading scope, and skip the 19-class classifier head.
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)
    model.train()
    model.to(device)
    model_D.train()
    model_D.to(device)
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps * args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)
    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)
    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        raise ValueError('unsupported GAN mode: {}'.format(args.gan))
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)
    # Upsample predictions back to (H, W) of the original crops.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)
    # labels for adversarial training
    source_label = 0
    target_label = 1
    # set up tensor board
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    writer = SummaryWriter(args.log_dir)
    for i_iter in range(args.num_steps):
        loss_seg_value1 = 0
        loss_seg_value2 = 0
        adv_loss_value = 0
        d_loss_value = 0
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        # Fixed: use the discriminator schedule (base lr = learning_rate_D),
        # not the generator one.
        adjust_learning_rate_D(optimizer_D, i_iter)
        for sub_i in range(args.iter_size):
            # train G
            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False
            # train with source: supervised segmentation on both heads.
            _, batch = trainloader_iter.__next__()
            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)
            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)
            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            # proper normalization over gradient-accumulation steps
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size
            # train with target: make D call target predictions "source" (0).
            _, batch = targetloader_iter.__next__()
            images, _, _ = batch
            images = images.to(device)
            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)
            D_out_target = CDAN(
                [F.softmax(pred_target1, dim=1), F.softmax(pred_target2, dim=1)],
                model_D,
                cdan_implement='concat')
            dc_source = torch.FloatTensor(
                D_out_target.size()).fill_(0).to(device)
            adv_loss = nn.BCEWithLogitsLoss()(D_out_target, dc_source)
            adv_loss = adv_loss / args.iter_size
            adv_loss = args.lambda_adv * adv_loss
            adv_loss.backward()
            adv_loss_value += adv_loss.item()
            # train D
            # bring back requires_grad
            for params in model_D.parameters():
                params.requires_grad_(requires_grad=True)
            # source branch: real label 0; detach so G gets no grads.
            pred1 = pred1.detach()
            pred2 = pred2.detach()
            D_out = CDAN([F.softmax(pred1, dim=1), F.softmax(pred2, dim=1)],
                         model_D,
                         cdan_implement='concat')
            dc_source = torch.FloatTensor(D_out.size()).fill_(0).to(device)
            d_loss = nn.BCEWithLogitsLoss()(D_out, dc_source)
            d_loss = d_loss / args.iter_size
            d_loss.backward()
            d_loss_value += d_loss.item()
            # target branch: fake label 1.
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()
            D_out_target = CDAN(
                [F.softmax(pred_target1, dim=1), F.softmax(pred_target2, dim=1)],
                model_D,
                cdan_implement='concat')
            dc_target = torch.FloatTensor(
                D_out_target.size()).fill_(1).to(device)
            d_loss = nn.BCEWithLogitsLoss()(D_out_target, dc_target)
            d_loss = d_loss / args.iter_size
            d_loss.backward()
            d_loss_value += d_loss.item()
        optimizer.step()
        optimizer_D.step()
        scalar_info = {
            'loss_seg1': loss_seg_value1,
            'loss_seg2': loss_seg_value2,
            'generator_loss': adv_loss_value,
            'discriminator_loss': d_loss_value,
        }
        if i_iter % 10 == 0:
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)
            print('exp = {}'.format(args.snapshot_dir))
            print(
                'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} generator = {4:.3f}, discriminator = {5:.3f}'
                .format(i_iter, args.num_steps, loss_seg_value1,
                        loss_seg_value2, adv_loss_value, d_loss_value))
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))
            # check_original_discriminator(args, pred_target1, pred_target2, i_iter)
            save_path = args.snapshot_dir + '/eval_{}'.format(i_iter)
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            # evaluate(args, save_path, args.snapshot_dir, i_iter)
            ###### also record latest saved iteration #######
            args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
            args_dict['learning_rate_D'] = optimizer_D.param_groups[0]['lr']
            args_dict['start_steps'] = i_iter
            # Fixed: write with a path separator so the RESTART branch above
            # (which reads snapshot_dir + '/args_dict_{}.json') finds it.
            args_dict_file = args.snapshot_dir + '/args_dict_{}.json'.format(
                i_iter)
            with open(args_dict_file, 'w') as f:
                json.dump(args_dict, f)
            ###### also record latest saved iteration #######
    writer.close()
def main():
    """Create the model and start the training (IRM variant).

    Trains DeeplabMulti with the segmentation loss plus an IRM-style penalty:
    a dummy scalar multiplier is fed through model.forward_irm and the squared
    gradient of the loss w.r.t. that scalar is added, weighted by
    args.lambda_irm. No discriminators are used in this variant.

    Fixes vs. previous revision:
    - the progress print referenced six undefined loss_*_value variables
      (NameError on the first iteration); it now reports the accumulators
      this loop actually maintains,
    - snapshot saving tried to save model_D1/model_D2 which do not exist in
      this function; only the segmentation model is saved,
    - loss extraction uses Tensor.item() instead of the legacy
      .data.cpu().numpy()[0], which fails on 0-d arrays in modern PyTorch.
    """
    # Input sizes come in as "W,H" strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)
    cudnn.enabled = True
    # Create network and load pre-trained weights.
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Keys look like "Scale.layer5.conv2d_list.3.weight"; drop the
            # leading scope, and skip the 19-class classifier head.
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)
    model.train()
    model.cuda(args.gpu)
    cudnn.benchmark = True
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps * args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)
    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)
    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    # Upsample predictions back to (H, W) of the original crops.
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear')
    # labels for adversarial training (kept for parity with sibling scripts)
    source_label = 0
    target_label = 1

    def _forward_single(_iter):
        """One forward/backward pass (seg loss + IRM penalty) on one batch.

        Returns the two un-normalized segmentation losses for logging; the
        backward pass has already been run on the combined, iter_size-
        normalized loss.
        """
        # Dummy multiplier the IRM penalty differentiates against.
        _scalar = torch.tensor(1.).cuda().requires_grad_()
        _, batch = _iter.__next__()
        # NOTE(review): assumes every loader yields 4-tuples
        # (images, labels, size, name); the target dataset elsewhere in this
        # file yields 3-tuples without labels — verify before use.
        images, labels, _, _ = batch
        images = Variable(images).cuda(args.gpu)
        pred1, pred2 = model.forward_irm(images, _scalar)
        # NOTE(review): `interp` (source size) is applied to both loaders;
        # if target crops differ in size, `interp_target` should be used for
        # the target branch — confirm input_size == input_size_target.
        pred1 = interp(pred1)
        pred2 = interp(pred2)
        loss_seg1 = loss_calc(pred1, labels, args.gpu)
        loss_seg2 = loss_calc(pred2, labels, args.gpu)
        loss = loss_seg2 + args.lambda_seg * loss_seg1
        # proper normalization over gradient-accumulation steps
        loss = loss / args.iter_size
        # IRM penalty: squared gradient of the loss w.r.t. the dummy scalar.
        grad = autograd.grad(loss, [_scalar], create_graph=True)[0]
        penalty = torch.sum(grad**2)
        loss += args.lambda_irm * penalty
        loss.backward()
        return loss_seg1, loss_seg2

    for i_iter in range(args.num_steps):
        loss_seg1_src = 0
        loss_seg2_src = 0
        loss_seg1_tgt = 0
        loss_seg2_tgt = 0
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        for sub_i in range(args.iter_size):
            # train with source
            loss_seg1, loss_seg2 = _forward_single(trainloader_iter)
            loss_seg1_src += loss_seg1.item() / args.iter_size
            loss_seg2_src += loss_seg2.item() / args.iter_size
            # train with target
            loss_seg1, loss_seg2 = _forward_single(targetloader_iter)
            loss_seg1_tgt += loss_seg1.item() / args.iter_size
            loss_seg2_tgt += loss_seg2.item() / args.iter_size
        optimizer.step()
        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, src: loss_seg1 = {2:.3f} loss_seg2 = {3:.3f}, tgt: loss_seg1 = {4:.3f} loss_seg2 = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg1_src, loss_seg2_src,
                    loss_seg1_tgt, loss_seg2_tgt))
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            break
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
def main():
    """Create the model and start the training.

    Standard two-discriminator output-space adaptation (AdaptSegNet style):
    model_D1 on the auxiliary head, model_D2 on the main head, with periodic
    Cityscapes mIoU evaluation every 1000 iterations.
    """
    device = torch.device("cuda" if not args.cpu else "cpu")
    # Input sizes come in as "W,H" strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)
    cudnn.enabled = True
    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        #model.load_state_dict(saved_state_dict)
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Keys look like "Scale.layer5.conv2d_list.3.weight"; drop the
            # leading scope, and skip the 19-class classifier head.
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)
    model.train()
    model.to(device)
    cudnn.benchmark = True
    # init D: one discriminator per prediction head, both on softmax outputs.
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)
    # model_D1.load_state_dict(torch.load('./snapshots/local_00002/GTA5_21000_D1.pth'))
    # model_D2.load_state_dict(torch.load('./snapshots/local_00002/GTA5_21000_D2.pth'))
    model_D1.train()
    model_D1.to(device)
    model_D2.train()
    model_D2.to(device)
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps * args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)
    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)
    # Load VGG
    #vgg19 = torchvision.models.vgg19(pretrained=True)
    #vgg19.to(device)
    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()
    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()
    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)
    # Upsample predictions back to (H, W) of the original crops.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)
    # labels for adversarial training
    source_label = 0
    target_label = 1
    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        writer = SummaryWriter(args.log_dir)
    for i_iter in range(0, args.num_steps):
        loss_seg_value = 0
        # loss_seg_local_value = 0
        loss_adv_target_value = 0
        # loss_adv_local_value = 0
        loss_D_value = 0
        # loss_D_local_value = 0
        loss_local_match_value = 0
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D1.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D2, i_iter)
        for sub_i in range(args.iter_size):
            # train G
            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False
            for param in model_D2.parameters():
                param.requires_grad = False
            # train with source: supervised segmentation on both heads
            # (backward is deferred until the adversarial term is added).
            _, batch = trainloader_iter.__next__()
            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)
            pred_s1, pred_s2, _, _ = model(images)
            #f_s2 = normalize(f_s2)
            pred_s1, pred_s2 = interp(pred_s1), interp(pred_s2)
            loss_seg = args.lambda_seg * seg_loss(pred_s1, labels) + seg_loss(
                pred_s2, labels)
            del labels
            # proper normalization
            loss_seg_value += loss_seg.item() / args.iter_size
            # train with target: make both discriminators call target
            # predictions "source" (0).
            _, batch = targetloader_iter.__next__()
            images, _, _ = batch
            images = images.to(device)
            pred_t1, pred_t2, _, _ = model(images)
            #f_t2 = normalize(f_t2)
            pred_t1, pred_t2 = interp_target(pred_t1), interp_target(pred_t2)
            del images
            D_out_1 = model_D1(F.softmax(pred_t1, dim=1))
            D_out_2 = model_D2(F.softmax(pred_t2, dim=1))
            loss_adv_target1 = bce_loss(
                D_out_1,
                torch.FloatTensor(
                    D_out_1.data.size()).fill_(source_label).to(device))
            loss_adv_target2 = bce_loss(
                D_out_2,
                torch.FloatTensor(
                    D_out_2.data.size()).fill_(source_label).to(device))
            # NOTE(review): the second head is weighted by
            # args.lambda_adv_target (no suffix) while the first uses
            # lambda_adv_target1 — possibly a typo for lambda_adv_target2;
            # confirm against the argument parser.
            loss_adv_target_value += (
                args.lambda_adv_target1 * loss_adv_target1 +
                args.lambda_adv_target * loss_adv_target2).item() / args.iter_size
            loss = loss_seg + args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target * loss_adv_target2
            del D_out_1, D_out_2
            # #< Local patch part>#
            # corres_id2 = get_correspondance(f_s2, f_t2, pred_s2, pred_t2)
            # #loss_local1 = local_feature_loss(corres_id1, f_s1, f_t1, model, seg_loss)
            # loss_local2 = local_feature_loss(corres_id2, labels, f_t2, model, seg_loss)
            # loss_local = args.lambda_match_target2 * loss_local2 #+args.lambda_match_target1 * loss_local1
            # loss += loss_local
            # if corres_id2.nelement() > 0:
            #     loss_local_match_value += loss_local.item()/ args.iter_size
            loss /= args.iter_size
            loss.backward()
            # train D
            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True
            # train with source (real = source_label); detach so G gets no
            # grads. NOTE(review): these F.softmax calls omit dim= and rely
            # on the deprecated implicit-dim behavior.
            pred_s1, pred_s2 = pred_s1.detach(), pred_s2.detach()
            D_out_1, D_out_2 = model_D1(F.softmax(pred_s1)), model_D2(
                F.softmax(pred_s2))
            loss_D_1 = bce_loss(
                D_out_1,
                torch.FloatTensor(D_out_1.data.size()).fill_(source_label).to(
                    device)) / args.iter_size / 2
            loss_D_2 = bce_loss(
                D_out_2,
                torch.FloatTensor(D_out_2.data.size()).fill_(source_label).to(
                    device)) / args.iter_size / 2
            loss_D_1.backward()
            loss_D_2.backward()
            loss_D_value += (loss_D_1 + loss_D_2).item()
            # train with target (fake = target_label).
            pred_t1, pred_t2 = pred_t1.detach(), pred_t2.detach()
            D_out_1, D_out_2 = model_D1(F.softmax(pred_t1)), model_D2(
                F.softmax(pred_t2))
            loss_D_1 = bce_loss(
                D_out_1,
                torch.FloatTensor(D_out_1.data.size()).fill_(target_label).to(
                    device)) / args.iter_size / 2
            loss_D_2 = bce_loss(
                D_out_2,
                torch.FloatTensor(D_out_2.data.size()).fill_(target_label).to(
                    device)) / args.iter_size / 2
            loss_D_1.backward()
            loss_D_2.backward()
            loss_D_value += (loss_D_1 + loss_D_2).item()
        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()
        # Periodic full-validation mIoU; also runs at i_iter == 0, so mIoU is
        # defined before the logging below first reads it.
        if i_iter % 1000 == 0:
            val_dir = '../dataset/Cityscapes/leftImg8bit_trainvaltest/'
            val_list = './dataset/cityscapes_list/val.txt'
            save_dir = './results/tmp'
            gt_dir = '../dataset/Cityscapes/gtFine_trainvaltest/gtFine/val'
            evaluate_cityscapes.test_model(model, device, val_dir, val_list,
                                           save_dir)
            mIoU = compute_iou.mIoUforTest(gt_dir, save_dir)
        if args.tensorboard:
            scalar_info = {
                'loss_seg': loss_seg_value,
                #'loss_seg_local': loss_seg_local_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_local_match': loss_local_match_value,
                'loss_D': loss_D_value,
                'mIoU': mIoU
                #'loss_D_local': loss_D_local_value
            }
            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)
        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f}, loss_D = {4:.3f}, loss_local_match = {5:.3f}, mIoU = {6:3f} '
            #'loss_seg_local = {5:.3f} loss_adv_local = {6:.3f}, loss_D_local = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_target_value, loss_D_value,
                    loss_local_match_value, mIoU)
            #loss_seg_local_value, loss_adv_local_value, loss_D_local_value)
        )
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
    if args.tensorboard:
        writer.close()
def main():
    """Create the model and start the training.

    Trains DeepLab on the OpenRooms source set with optional adversarial
    adaptation to NYU (target). Mode "sda" additionally applies a
    supervised loss on the labelled subset of each target batch; modes
    "baseline" / "baseline_tar" skip the discriminators entirely.
    Relies on module-level `args`, dataset classes and helpers defined
    elsewhere in this file.
    """
    # OpenRooms label id -> NYU-40 label id (shifted to 0-based by the -1
    # below); ids missing from the dict map to themselves.
    or_nyu_dict = {
        0: 255, 1: 16, 2: 40, 3: 39, 4: 7, 5: 14, 6: 39, 7: 12, 8: 38,
        9: 40, 10: 10, 11: 6, 12: 40, 13: 39, 14: 39, 15: 40, 16: 18,
        17: 40, 18: 4, 19: 40, 20: 40, 21: 5, 22: 40, 23: 40, 24: 30,
        25: 36, 26: 38, 27: 40, 28: 3, 29: 40, 30: 40, 31: 9, 32: 38,
        33: 40, 34: 40, 35: 40, 36: 34, 37: 37, 38: 40, 39: 40, 40: 39,
        41: 8, 42: 3, 43: 1, 44: 2, 45: 22
    }
    or_nyu_map = lambda x: or_nyu_dict.get(x, x) - 1
    or_nyu_map = np.vectorize(or_nyu_map)

    device = torch.device("cuda" if not args.cpu else "cpu")

    # Parse "W,H" strings for the source and target crop sizes.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    args.or_nyu_map = or_nyu_map

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        elif args.restore_from == "":
            saved_state_dict = None
        else:
            saved_state_dict = torch.load(args.restore_from)

        # FIX: the original iterated `saved_state_dict` unconditionally,
        # which raises TypeError when restore_from == "" (None above).
        if saved_state_dict is not None:
            new_params = model.state_dict().copy()
            for i in saved_state_dict:
                # Scale.layer5.conv2d_list.3.weight
                i_parts = i.split('.')
                # Skip the classifier head ('layer5') when the pre-trained
                # checkpoint's class count (40) matches ours.
                if not args.num_classes == 40 or not i_parts[1] == 'layer5':
                    new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
            model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    if args.mode != "baseline" and args.mode != "baseline_tar":
        # init D
        model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
        model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)
        model_D1.train()
        model_D1.to(device)
        model_D2.train()
        model_D2.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Augmentation ranges for the training transform.
    scale_min = 0.5
    scale_max = 2.0
    rotate_min = -10
    rotate_max = 10
    ignore_label = 255
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    # NOTE(review): w/h were reassigned above, so these hold the *target*
    # size — presumably intentional since both domains share the crop.
    args.width = w
    args.height = h

    train_transform = transforms.Compose([
        transforms.RandScale([scale_min, scale_max]),
        transforms.RandRotate([rotate_min, rotate_max],
                              padding=IMG_MEAN_RGB,
                              ignore_label=ignore_label),
        transforms.RandomGaussianBlur(),
        transforms.RandomHorizontalFlip(),
        transforms.Crop([args.height + 1, args.width + 1],
                        crop_type='rand',
                        padding=IMG_MEAN_RGB,
                        ignore_label=ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMG_MEAN, std=[1, 1, 1]),
    ])
    val_transform = transforms.Compose([
        transforms.Crop([args.height + 1, args.width + 1],
                        crop_type='center',
                        padding=IMG_MEAN_RGB,
                        ignore_label=ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMG_MEAN, std=[1, 1, 1]),
    ])

    if args.mode != "baseline_tar":
        src_train_dst = OpenRoomsSegmentation(root=args.data_dir,
                                              opt=args,
                                              split='train',
                                              transform=train_transform,
                                              imWidth=args.width,
                                              imHeight=args.height,
                                              remap_labels=args.or_nyu_map)
    else:
        # "baseline_tar" trains only on the labelled part of the target set.
        src_train_dst = NYU_Labelled(root=args.data_dir_target,
                                     opt=args,
                                     split='train',
                                     transform=train_transform,
                                     imWidth=args.width,
                                     imHeight=args.height,
                                     phase="TRAIN",
                                     randomize=True)
    tar_train_dst = NYU(root=args.data_dir_target,
                        opt=args,
                        split='train',
                        transform=train_transform,
                        imWidth=args.width,
                        imHeight=args.height,
                        phase="TRAIN",
                        randomize=True,
                        mode=args.mode)
    tar_val_dst = NYU(root=args.data_dir,
                      opt=args,
                      split='val',
                      transform=val_transform,
                      imWidth=args.width,
                      imHeight=args.height,
                      phase="TRAIN",
                      randomize=False)

    trainloader = data.DataLoader(src_train_dst,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(tar_train_dst,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    if args.mode != "baseline" and args.mode != "baseline_tar":
        optimizer_D1 = optim.Adam(model_D1.parameters(),
                                  lr=args.learning_rate_D,
                                  betas=(0.9, 0.99))
        optimizer_D1.zero_grad()
        optimizer_D2 = optim.Adam(model_D2.parameters(),
                                  lr=args.learning_rate_D,
                                  betas=(0.9, 0.99))
        optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1] + 1, input_size[0] + 1),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1] + 1,
                                      input_size_target[0] + 1),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_seg_value1_tar = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        loss_seg_value2 = 0
        loss_seg_value2_tar = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        if args.mode != "baseline" and args.mode != "baseline_tar":
            optimizer_D1.zero_grad()
            optimizer_D2.zero_grad()
            adjust_learning_rate_D(optimizer_D1, i_iter)
            adjust_learning_rate_D(optimizer_D2, i_iter)

        # Samples kept aside for periodic tensorboard image dumps.
        sample_src = None
        sample_tar = None
        sample_res_src = None
        sample_res_tar = None
        sample_gt_src = None
        sample_gt_tar = None

        for sub_i in range(args.iter_size):

            # train G
            if args.mode != "baseline" and args.mode != "baseline_tar":
                # don't accumulate grads in D
                for param in model_D1.parameters():
                    param.requires_grad = False
                for param in model_D2.parameters():
                    param.requires_grad = False

            # train with source; restart the loader when exhausted.
            try:
                _, batch = trainloader_iter.__next__()
            except StopIteration:  # FIX: was a bare `except:`
                trainloader_iter = enumerate(trainloader)
                _, batch = trainloader_iter.__next__()

            images, labels, _ = batch
            sample_src = images.clone()
            sample_gt_src = labels.clone()
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)
            sample_pred_src = pred2.detach().cpu()

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target; restart the loader when exhausted.
            try:
                _, batch = targetloader_iter.__next__()
            except StopIteration:  # FIX: was a bare `except:`
                targetloader_iter = enumerate(targetloader)
                _, batch = targetloader_iter.__next__()

            images, tar_labels, _, labelled = batch
            n_labelled = labelled.sum().detach().item()
            batch_size = images.shape[0]
            sample_tar = images.clone()
            sample_gt_tar = tar_labels.clone()
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            if args.mode == "sda" and n_labelled != 0:
                # Supervised loss on the labelled subset of the target batch.
                labelled = labelled.to(device) == 1
                tar_labels = tar_labels.to(device)
                loss_seg1_tar = seg_loss(pred_target1[labelled],
                                         tar_labels[labelled])
                loss_seg2_tar = seg_loss(pred_target2[labelled],
                                         tar_labels[labelled])
                loss_tar_labelled = loss_seg2_tar + args.lambda_seg * loss_seg1_tar
                loss_tar_labelled = loss_tar_labelled / args.iter_size
                loss_seg_value1_tar += loss_seg1_tar.item() / args.iter_size
                loss_seg_value2_tar += loss_seg2_tar.item() / args.iter_size
            else:
                loss_tar_labelled = torch.zeros(
                    1, requires_grad=True).float().to(device)

            # proper normalization
            sample_pred_tar = pred_target2.detach().cpu()

            if args.mode != "baseline" and args.mode != "baseline_tar":
                # Fool the discriminators: target predictions are scored
                # against the *source* label. dim=1 made explicit (class
                # channel); implicit-dim softmax is deprecated.
                D_out1 = model_D1(F.softmax(pred_target1, dim=1))
                D_out2 = model_D2(F.softmax(pred_target2, dim=1))

                loss_adv_target1 = bce_loss(
                    D_out1,
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label).to(device))
                loss_adv_target2 = bce_loss(
                    D_out2,
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label).to(device))
                loss = args.lambda_adv_target1 * loss_adv_target1 \
                    + args.lambda_adv_target2 * loss_adv_target2
                # loss_tar_labelled is already iter_size-normalized above.
                loss = loss / args.iter_size + loss_tar_labelled
                loss.backward()
                loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
                loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

                # train D
                # bring back requires_grad
                for param in model_D1.parameters():
                    param.requires_grad = True
                for param in model_D2.parameters():
                    param.requires_grad = True

                # train with source
                pred1 = pred1.detach()
                pred2 = pred2.detach()
                D_out1 = model_D1(F.softmax(pred1, dim=1))
                D_out2 = model_D2(F.softmax(pred2, dim=1))
                loss_D1 = bce_loss(
                    D_out1,
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label).to(device))
                loss_D2 = bce_loss(
                    D_out2,
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label).to(device))
                loss_D1 = loss_D1 / args.iter_size / 2
                loss_D2 = loss_D2 / args.iter_size / 2
                loss_D1.backward()
                loss_D2.backward()
                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

                # train with target
                pred_target1 = pred_target1.detach()
                pred_target2 = pred_target2.detach()
                D_out1 = model_D1(F.softmax(pred_target1, dim=1))
                D_out2 = model_D2(F.softmax(pred_target2, dim=1))
                loss_D1 = bce_loss(
                    D_out1,
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(target_label).to(device))
                loss_D2 = bce_loss(
                    D_out2,
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label).to(device))
                loss_D1 = loss_D1 / args.iter_size / 2
                loss_D2 = loss_D2 / args.iter_size / 2
                loss_D1.backward()
                loss_D2.backward()
                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

        optimizer.step()
        if args.mode != "baseline" and args.mode != "baseline_tar":
            optimizer_D1.step()
            optimizer_D2.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg1': loss_seg_value1,
                'loss_seg2': loss_seg_value2,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D1': loss_D_value1,
                'loss_D2': loss_D_value2,
                'loss_seg1_tar': loss_seg_value1_tar,
                'loss_seg2_tar': loss_seg_value2_tar,
            }
            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

            if i_iter % 1000 == 0:
                # De-normalize for display: BGR -> RGB, add the mean back.
                img = sample_src.cpu()[:, [2, 1, 0], :, :] + torch.from_numpy(
                    np.array(IMG_MEAN_RGB).reshape(1, 3, 1, 1)).float()
                img = img.type(torch.uint8)
                writer.add_images("Src/Images", img, i_iter)
                label = tar_train_dst.decode_target(sample_gt_src).transpose(
                    0, 3, 1, 2)
                writer.add_images("Src/Labels", label, i_iter)
                preds = sample_pred_src.permute(0, 2, 3, 1).cpu().numpy()
                preds = np.asarray(np.argmax(preds, axis=3), dtype=np.uint8)
                preds = tar_train_dst.decode_target(preds).transpose(
                    0, 3, 1, 2)
                writer.add_images("Src/Preds", preds, i_iter)
                tar_img = sample_tar.cpu()[:, [2, 1, 0], :, :] + torch.from_numpy(
                    np.array(IMG_MEAN_RGB).reshape(1, 3, 1, 1)).float()
                tar_img = tar_img.type(torch.uint8)
                writer.add_images("Tar/Images", tar_img, i_iter)
                tar_label = tar_train_dst.decode_target(
                    sample_gt_tar).transpose(0, 3, 1, 2)
                writer.add_images("Tar/Labels", tar_label, i_iter)
                tar_preds = sample_pred_tar.permute(0, 2, 3, 1).cpu().numpy()
                tar_preds = np.asarray(np.argmax(tar_preds, axis=3),
                                       dtype=np.uint8)
                tar_preds = tar_train_dst.decode_target(tar_preds).transpose(
                    0, 3, 1, 2)
                writer.add_images("Tar/Preds", tar_preds, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f} loss_seg1_tar={8:.3f} loss_seg2_tar={9:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2, loss_seg_value1_tar,
                    loss_seg_value2_tar))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'OR_' + str(args.num_steps_stop) + '.pth'))
            if args.mode != "baseline" and args.mode != "baseline_tar":
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(args.num_steps_stop) + '_D1.pth'))
                torch.save(
                    model_D2.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'OR_' + str(i_iter) + '.pth'))
            if args.mode != "baseline" and args.mode != "baseline_tar":
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(i_iter) + '_D1.pth'))
                torch.save(
                    model_D2.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
def main():
    """Create the model and start the training.

    GTA5 -> Cityscapes adversarial adaptation with restart support: when
    the module-level RESTART flag is set, model/discriminator weights and
    the argument dict are reloaded from RESTART_FROM at RESTART_ITER, and
    training resumes at args.start_steps.
    """
    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)

    args_dict = vars(args)
    import json

    ###### load args for restart ######
    if RESTART:
        # NOTE(review): this path has no '/' between dir and file while the
        # save path below uses one — works only if RESTART_FROM ends in '/'.
        args_dict_file = args.snapshot_dir + 'args_dict_{}.json'.format(
            RESTART_ITER)
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        for arg in args_dict:
            args_dict[arg] = args_dict_last[arg]
    ###### load args for restart ######

    device = torch.device("cuda" if not args.cpu else "cpu")

    # Parse "W,H" strings for the source and target crop sizes.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)

    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    #### restore model_D1, D2 and model
    if RESTART:
        # model parameters
        restart_from_model = args.restart_from + 'GTA5_{}.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_model)
        model.load_state_dict(saved_state_dict)

        # model_D1 parameters
        restart_from_D1 = args.restart_from + 'GTA5_{}_D1.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_D1)
        model_D1.load_state_dict(saved_state_dict)

        # model_D2 parameters
        restart_from_D2 = args.restart_from + 'GTA5_{}_D2.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_D2)
        model_D2.load_state_dict(saved_state_dict)
    #### model_D1, D2 are randomly initialized, model is pre-trained ResNet on ImageNet
    else:
        # model parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # Skip the 19-class head ('layer5') when class counts match.
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.to(device)
    model_D1.train()
    model_D1.to(device)
    model_D2.train()
    model_D2.to(device)

    #### From here, code should not be related to model reload ####
    # but we would need hyperparameters: n_iter,
    # [lr, momentum, weight_decay, betas](these are all in args)
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(
        args.data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()
    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.start_steps, args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G
            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False
            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            _, batch = trainloader_iter.__next__()
            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            # FIX: removed a leftover pdb.set_trace() breakpoint here that
            # halted training every iteration.
            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target
            _, batch = targetloader_iter.__next__()
            images, _, _ = batch
            images = images.to(device)

            # FIX: two more pdb.set_trace() breakpoints removed here.
            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # Fool the discriminators: target predictions are scored
            # against the *source* label. dim=1 made explicit (class
            # channel); implicit-dim softmax is deprecated.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_adv_target1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label).to(device))
            loss_adv_target2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label).to(device))
            loss = args.lambda_adv_target1 * loss_adv_target1 \
                + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # train D
            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred1 = pred1.detach()
            pred2 = pred2.detach()
            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label).to(device))
            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2
            loss_D1.backward()
            loss_D2.backward()
            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(target_label).to(device))
            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(target_label).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2
            loss_D1.backward()
            loss_D2.backward()
            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        scalar_info = {
            'loss_seg1': loss_seg_value1,
            'loss_seg2': loss_seg_value2,
            'loss_adv_target1': loss_adv_target_value1,
            'loss_adv_target2': loss_adv_target_value2,
            'loss_D1': loss_D_value1,
            'loss_D2': loss_D_value2,
        }
        if i_iter % 10 == 0:
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(i_iter) + '_D2.pth'))

            ###### also record latest saved iteration #######
            args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
            args_dict['learning_rate_D'] = optimizer_D1.param_groups[0]['lr']
            args_dict['start_steps'] = i_iter
            args_dict_file = args.snapshot_dir + '/args_dict_{}.json'.format(
                i_iter)
            with open(args_dict_file, 'w') as f:
                json.dump(args_dict, f)
            ###### also record latest saved iteration #######

    writer.close()
def main():
    """Create the model and start the training.

    GTA5 -> Cityscapes feature-level adaptation: a single discriminator
    with an auxiliary classifier head (it returns `(cla, D_out)`) operates
    on backbone features rather than on segmentation maps. Relies on
    module-level `args`, dataset classes and helpers defined elsewhere in
    this file.
    """
    device = torch.device("cuda" if not args.cpu else "cpu")

    # Parse "W,H" strings for the source and target crop sizes.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'ResNet':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # Copy every weight except the 19-class head ('layer5') when the
            # checkpoint's class count matches, so the head is re-initialized.
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    if args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes,
                           vgg16_caffe_path='./model/vgg16_init.pth',
                           pretrained=True)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D — input width follows the backbone's feature dimension
    # (2048 for ResNet, 1024 for VGG).
    if args.model == 'ResNet':
        model_D = FCDiscriminator(num_classes=2048).to(device)
    if args.model == 'VGG':
        model_D = FCDiscriminator(num_classes=1024).to(device)

    model_D.train()
    model_D.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(
        args.data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # Upsample logits back to the input resolution before computing losses.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg = 0
        loss_adv_target_value = 0
        loss_D_value = 0
        loss_cla_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train G
        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source
        _, batch = trainloader_iter.__next__()
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        feature, prediction = model(images)
        prediction = interp(prediction)
        loss = seg_loss(prediction, labels)
        loss.backward()
        loss_seg = loss.item()

        # train with target
        _, batch = targetloader_iter.__next__()
        images, _, _ = batch
        images = images.to(device)

        feature_target, _ = model(images)
        _, D_out = model_D(feature_target)
        # Generator objective: make target features score as "source".
        loss_adv_target = bce_loss(
            D_out,
            torch.FloatTensor(
                D_out.data.size()).fill_(source_label).to(device))
        #print(args.lambda_adv_target)
        loss = args.lambda_adv_target * loss_adv_target
        loss.backward()
        loss_adv_target_value = loss_adv_target.item()

        # train D
        # bring back requires_grad
        for param in model_D.parameters():
            param.requires_grad = True

        # train with source — detach so gradients stop at the discriminator.
        feature = feature.detach()
        cla, D_out = model_D(feature)
        cla = interp(cla)
        # Auxiliary classification loss on the discriminator's `cla` head,
        # weighted by lambda_s alongside the real/fake loss.
        loss_cla = seg_loss(cla, labels)
        loss_D = bce_loss(
            D_out,
            torch.FloatTensor(
                D_out.data.size()).fill_(source_label).to(device))
        loss_D = loss_D / 2
        #print(args.lambda_s)
        loss_Disc = args.lambda_s * loss_cla + loss_D
        loss_Disc.backward()
        loss_cla_value = loss_cla.item()
        loss_D_value = loss_D.item()

        # train with target
        feature_target = feature_target.detach()
        _, D_out = model_D(feature_target)
        loss_D = bce_loss(
            D_out,
            torch.FloatTensor(
                D_out.data.size()).fill_(target_label).to(device))
        loss_D = loss_D / 2
        loss_D.backward()
        loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg': loss_seg,
                'loss_cla': loss_cla_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_D': loss_D_value,
            }
            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        #print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f} loss_D = {4:.3f} loss_cla = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg, loss_adv_target_value,
                    loss_D_value, loss_cla_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(i_iter) + '_D.pth'))

    if args.tensorboard:
        writer.close()
def main():
    """Create the segmentation model and run adversarial adaptation training.

    Builds a multi-head DeepLab generator plus two fully-convolutional
    discriminators (one per output head), then alternates each iteration
    between (1) updating the generator with the supervised source
    segmentation loss and adversarial losses that push target-domain
    predictions to look source-like, and (2) updating the discriminators to
    separate source from target prediction maps.

    All configuration is read from the module-level ``args`` namespace.
    Snapshots of all three networks are written to ``args.snapshot_dir``;
    scalar losses go to TensorBoard when ``args.tensorboard`` is set.
    """
    device = torch.device("cuda" if not args.cpu else "cpu")

    # Crop sizes are supplied as "width,height" strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network and restore pre-trained weights.
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight";
            # skip the classifier head ('layer5') when training the standard
            # 19-class setup so the freshly initialised head is kept.
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # Init the discriminators, one per DeepLab output head.
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    model_D1.train()
    model_D1.to(device)

    model_D2.train()
    model_D2.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # max_iters sizes the dataset so the iterators never exhaust mid-run.
    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list,
                    max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror,
                    mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
        cityscapesDataSet(args.data_dir_target, args.data_list_target,
                          max_iters=args.num_steps * args.iter_size * args.batch_size,
                          crop_size=input_size_target,
                          scale=False, mirror=args.random_mirror,
                          mean=IMG_MEAN, set=args.set),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # Freeze D so generator updates don't accumulate grads in D.
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source: plain supervised segmentation loss
            _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization over accumulated sub-iterations
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target: fool D into labelling target maps as source
            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # dim=1 (class channel) made explicit: the bare F.softmax(x)
            # form is deprecated and emits a warning while resolving to the
            # same dimension for these NCHW tensors.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_adv_target1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label).to(device))

            loss_adv_target2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label).to(device))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source; detach so no gradients flow back into G
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label).to(device))

            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label).to(device))

            # halved so the source and target passes together weigh one step
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(target_label).to(device))

            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(target_label).to(device))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg1': loss_seg_value1,
                'loss_seg2': loss_seg_value2,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D1': loss_D_value1,
                'loss_D2': loss_D_value2,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
def main(): """Create the model and start the training.""" w, h = map(int, args.input_size.split(',')) input_size = (w, h) w, h = map(int, args.input_size_target.split(',')) input_size_target = (w, h) cudnn.enabled = True gpu = args.gpu # Create network if args.model == 'DeepLab': model = DeeplabMulti(num_classes=args.num_classes) if args.restore_from[:4] == 'http': saved_state_dict = model_zoo.load_url(args.restore_from) else: saved_state_dict = torch.load(args.restore_from) new_params = model.state_dict().copy() for i in saved_state_dict: # Scale.layer5.conv2d_list.3.weight i_parts = i.split('.') # print i_parts if args.num_classes != 19 and i_parts[1] != 'layer5': new_params['.'.join(i_parts[1:])] = saved_state_dict[i] # if args.num_classes !=19: # print i_parts model.load_state_dict(new_params) start_time = datetime.datetime.now().strftime('%m-%d_%H-%M') writer_dir = os.path.join("./logs/", args.name, start_time) writer = tensorboard.SummaryWriter(writer_dir) model.train() model.cuda(args.gpu) cudnn.benchmark = True # init D model_D1 = FCDiscriminator(num_classes=args.num_classes) model_D2 = FCDiscriminator(num_classes=args.num_classes) model_D1.train() model_D1.cuda(args.gpu) model_D2.train() model_D2.cuda(args.gpu) if not os.path.exists(args.snapshot_dir): os.makedirs(args.snapshot_dir) trainloader = data.DataLoader( MulitviewSegLoader( num_classes=args.num_classes, root=args.data_dir, number_views=2, view_idx=1, # max_iters=args.num_steps * args.iter_size * args.batch_size, # crop_size=input_size, # scale=args.random_scale, mirror=args.random_mirror, img_mean=IMG_MEAN), batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True) trainloader_iter = iter(data_loader_cycle(trainloader)) targetloader = data.DataLoader( MulitviewSegLoader( root=args.data_dir_target, num_classes=args.num_classes, number_views=1, view_idx=0, # max_iters=args.num_steps * args.iter_size * args.batch_size, # crop_size=input_size_target, # scale=False, # 
mirror=args.random_mirror, img_mean=IMG_MEAN, # set=args.set ), batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True) interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear') interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear') def mdl_val_func(x): return interp_target(model(x)[1]) targetloader_iter = iter(data_loader_cycle(targetloader)) val_loader = data.DataLoader( MulitviewSegLoader( root=args.data_dir_val, number_views=1, view_idx=0, num_classes=args.num_classes, # max_iters=args.num_steps * args.iter_size * args.batch_size, # crop_size=input_size_target, # scale=False, # mirror=args.random_mirror, img_mean=IMG_MEAN, # set=args.set ), batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True) criterion = CrossEntropyLoss2d().cuda(args.gpu) valhelper = ValHelper(gpu=args.gpu, model=mdl_val_func, val_loader=val_loader, loss=criterion, writer=writer) # implement model.optim_parameters(args) to handle different models' lr setting optimizer = optim.SGD(model.optim_parameters(args), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) optimizer.zero_grad() optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) optimizer_D1.zero_grad() optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) optimizer_D2.zero_grad() if args.gan == 'Vanilla': bce_loss = torch.nn.BCEWithLogitsLoss() elif args.gan == 'LS': bce_loss = torch.nn.MSELoss() # labels for adversarial training source_label = 0 target_label = 1 for i_iter in range(args.num_steps): loss_seg_value1 = 0 loss_adv_target_value1 = 0 loss_D_value1 = 0 loss_seg_value2 = 0 loss_adv_target_value2 = 0 loss_D_value2 = 0 optimizer.zero_grad() adjust_learning_rate(optimizer, i_iter) optimizer_D1.zero_grad() optimizer_D2.zero_grad() adjust_learning_rate_D(optimizer_D1, i_iter) 
adjust_learning_rate_D(optimizer_D2, i_iter) for sub_i in range(args.iter_size): # train G # don't accumulate grads in D for param in model_D1.parameters(): param.requires_grad = False for param in model_D2.parameters(): param.requires_grad = False # train with source batch = next(trainloader_iter) images, labels, *_ = batch images = Variable(images).cuda(args.gpu) pred1, pred2 = model(images) pred1 = interp(pred1) pred2 = interp(pred2) loss_seg1 = loss_calc(pred1, labels, args.gpu, criterion) loss_seg2 = loss_calc(pred2, labels, args.gpu, criterion) loss = loss_seg2 + args.lambda_seg * loss_seg1 # proper normalization loss = loss / args.iter_size loss.backward() loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size # train with target batch = next(targetloader_iter) images, *_ = batch images = Variable(images).cuda(args.gpu) pred_target1, pred_target2 = model(images) pred_target1 = interp_target(pred_target1) pred_target2 = interp_target(pred_target2) D_out1 = model_D1(F.softmax(pred_target1)) D_out2 = model_D2(F.softmax(pred_target2)) loss_adv_target1 = bce_loss( D_out1, Variable( torch.FloatTensor( D_out1.data.size()).fill_(source_label)).cuda( args.gpu)) loss_adv_target2 = bce_loss( D_out2, Variable( torch.FloatTensor( D_out2.data.size()).fill_(source_label)).cuda( args.gpu)) loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2 loss = loss / args.iter_size loss.backward() loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy( ) / args.iter_size loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy( ) / args.iter_size # train D # bring back requires_grad for param in model_D1.parameters(): param.requires_grad = True for param in model_D2.parameters(): param.requires_grad = True # train with source pred1 = pred1.detach() pred2 = pred2.detach() D_out1 = model_D1(F.softmax(pred1)) D_out2 = model_D2(F.softmax(pred2)) loss_D1 = bce_loss( 
D_out1, Variable( torch.FloatTensor( D_out1.data.size()).fill_(source_label)).cuda( args.gpu)) loss_D2 = bce_loss( D_out2, Variable( torch.FloatTensor( D_out2.data.size()).fill_(source_label)).cuda( args.gpu)) loss_D1 = loss_D1 / args.iter_size / 2 loss_D2 = loss_D2 / args.iter_size / 2 loss_D1.backward() loss_D2.backward() loss_D_value1 += loss_D1.data.cpu().numpy() loss_D_value2 += loss_D2.data.cpu().numpy() # train with target pred_target1 = pred_target1.detach() pred_target2 = pred_target2.detach() D_out1 = model_D1(F.softmax(pred_target1)) D_out2 = model_D2(F.softmax(pred_target2)) loss_D1 = bce_loss( D_out1, Variable( torch.FloatTensor( D_out1.data.size()).fill_(target_label)).cuda( args.gpu)) loss_D2 = bce_loss( D_out2, Variable( torch.FloatTensor( D_out2.data.size()).fill_(target_label)).cuda( args.gpu)) loss_D1 = loss_D1 / args.iter_size / 2 loss_D2 = loss_D2 / args.iter_size / 2 loss_D1.backward() loss_D2.backward() loss_D_value1 += loss_D1.data.cpu().numpy() loss_D_value2 += loss_D2.data.cpu().numpy() optimizer.step() optimizer_D1.step() optimizer_D2.step() if i_iter % args.val_steps == 0 and i_iter: model.eval() log = valhelper.valid_epoch(i_iter) print('log: {}'.format(log)) model.train() if i_iter % 10 == 0 and i_iter: print('exp = {}'.format(args.snapshot_dir)) print( 'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}' .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2, loss_adv_target_value1, loss_adv_target_value2, loss_D_value1, loss_D_value2)) writer.add_scalar(f'train/loss_seg_value1', loss_seg_value1, i_iter) writer.add_scalar(f'train/loss_seg_value2', loss_seg_value2, i_iter) writer.add_scalar(f'train/loss_adv_target_value1', loss_adv_target_value1, i_iter) writer.add_scalar(f'train/loss_D_value1', loss_D_value1, i_iter) writer.add_scalar(f'train/loss_D_value2', loss_D_value2, i_iter) if i_iter >= args.num_steps_stop - 1: print('save model 
...') torch.save( model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth')) torch.save( model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D1.pth')) torch.save( model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D2.pth')) break if i_iter % args.save_pred_every == 0 and i_iter != 0: print('taking snapshot ...') torch.save( model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth')) torch.save( model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth')) torch.save( model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))