def main():
    """Evaluate the trained discriminator's domain-classification accuracy.

    Feeds saliency and edge images through the restored generator, passes both
    predictions to the restored discriminator, and reports how often the
    discriminator recovers the true domain label.
    """
    # create the generator/segmentation model and restore its weights
    model = build_model()
    model.to(device)
    model.load_state_dict(torch.load(args.restore_from))
    model.eval()  # fixed: evaluation must not run with train-mode BN/dropout

    # create the discriminator and restore its weights
    model_D1 = FCDiscriminator(num_classes=1)
    model_D1.to(device)
    model_D1.load_state_dict(torch.load(args.D_restore_from))
    model_D1.eval()

    # labels for adversarial training: the two domain tags
    salLabel = 0
    edgeLabel = 1

    picloader = get_loader(args)

    correct = 0
    tot = 0
    # fixed: no gradients are needed while testing the discriminator; without
    # no_grad() every forward pass retained an autograd graph
    with torch.no_grad():
        for i_iter, data_batch in enumerate(picloader):
            tot += 2
            sal_image = data_batch['sal_image'].to(device)
            edge_image = data_batch['edge_image'].to(device)

            sal_pred = model(sal_image)
            edge_pred = model(edge_image)

            # `pan` maps a discriminator output to a predicted domain label
            ss_out = model_D1(sal_pred)
            se_out = model_D1(edge_pred)
            if pan(ss_out) == salLabel:
                correct += 1
            if pan(se_out) == edgeLabel:
                correct += 1
            if i_iter % 100 == 0:
                print('processing %d: %f' % (i_iter, correct / tot))
    print(correct / tot)
def main():
    """Create the models and start adversarial domain-adaptation training.

    Trains a two-branch DeepLab segmenter on source (GTA5) with supervision
    and aligns it to target (Cityscapes) via two output-level discriminators.
    The target adversarial loss is scaled by an adaptive weight derived from
    each sample's "domainess" value.
    """
    device = torch.device("cuda" if not args.cpu else "cpu")

    # source and target input resolutions, parsed from "W,H" strings
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # create the segmentation network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            # download pretrained weights
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            # load a local .pth checkpoint
            saved_state_dict = torch.load(args.restore_from)
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # e.g. Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # skip the old classification head when transferring to 19 classes
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.to(device)
    cudnn.benchmark = True

    # create the discriminators
    # fixed: dropped the redundant second .to(device) calls
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D1.train()
    model_D2.train()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list,
                    max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size, scale=args.random_scale,
                    mirror=args.random_mirror, mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
        cityscapesDataSet(args.data_dir_target, args.data_list_target,
                          max_iters=args.num_steps * args.iter_size * args.batch_size,
                          crop_size=input_size_target, scale=False,
                          mirror=args.random_mirror, mean=IMG_MEAN,
                          set=args.set),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args), lr=args.learning_rate,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()
    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        # fixed: previously fell through and raised NameError at first use
        raise ValueError('unknown GAN mode: %s' % args.gan)
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensorboard
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):
            # --- train G: freeze D so its grads don't accumulate ---
            for param in model_D1.parameters():
                param.requires_grad = False
            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            _, batch = trainloader_iter.__next__()
            images, labels, _, _, domainess = batch
            # adaptive adversarial weight sqrt(1 - domainess); it only scales
            # the loss, so detach it and move it to the compute device
            # (fixed: was a CPU tensor with a post-hoc requires_grad=False,
            # which mismatches the GPU-resident loss tensor)
            adw = torch.sqrt(1 - domainess).float().detach().to(device)
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            # proper normalization over accumulated sub-iterations
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target
            _, batch = targetloader_iter.__next__()
            images, _, _ = batch
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # fixed: F.softmax needs an explicit dim (the class axis)
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # fool the discriminators: label target predictions as source.
            # NOTE(review): adw has one entry per sample, so the scalar
            # backward() below assumes batch_size == 1 — confirm in run config.
            loss_adv_target1 = adw * bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(source_label).to(device))
            loss_adv_target2 = adw * bce_loss(
                D_out2,
                torch.FloatTensor(D_out2.data.size()).fill_(source_label).to(device))
            loss = (args.lambda_adv_target1 * loss_adv_target1
                    + args.lambda_adv_target2 * loss_adv_target2)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # --- train D: unfreeze its parameters again ---
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detach so G receives no gradient)
            pred1 = pred1.detach()
            pred2 = pred2.detach()
            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(source_label).to(device))
            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(D_out2.data.size()).fill_(source_label).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2
            loss_D1.backward()
            loss_D2.backward()
            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(target_label).to(device))
            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(D_out2.data.size()).fill_(target_label).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2
            loss_D1.backward()
            loss_D2.backward()
            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg1': loss_seg_value1,
                'loss_seg2': loss_seg_value2,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D1': loss_D_value1,
                'loss_D2': loss_D_value2,
            }
            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'.format(
                i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                loss_adv_target_value1, loss_adv_target_value2,
                loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir,
                                'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(model_D1.state_dict(),
                       osp.join(args.snapshot_dir,
                                'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(model_D2.state_dict(),
                       osp.join(args.snapshot_dir,
                                'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(model_D1.state_dict(),
                       osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(model_D2.state_dict(),
                       osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
def main():
    """Create the model and start the training.

    Feature-level adaptation: a single discriminator takes backbone features
    and also carries an auxiliary classifier head (`cla`) trained with the
    segmentation loss on source features.
    """
    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'ResNet':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # e.g. Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # skip the old classification head when transferring weights
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)
    elif args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes,
                           vgg16_caffe_path='./model/vgg16_init.pth',
                           pretrained=True)
    else:
        # fixed: an unknown model name previously left `model` undefined and
        # crashed later with a NameError
        raise ValueError('unsupported model: %s' % args.model)

    model.train()
    model.to(device)
    cudnn.benchmark = True

    # init D — channel width matches the backbone's feature dimension
    if args.model == 'ResNet':
        model_D = FCDiscriminator(num_classes=2048).to(device)
    else:  # 'VGG' (validated above)
        model_D = FCDiscriminator(num_classes=1024).to(device)
    model_D.train()
    model_D.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list,
                    max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size, scale=args.random_scale,
                    mirror=args.random_mirror, mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
        cityscapesDataSet(args.data_dir_target, args.data_list_target,
                          max_iters=args.num_steps * args.iter_size * args.batch_size,
                          crop_size=input_size_target, scale=False,
                          mirror=args.random_mirror, mean=IMG_MEAN,
                          set=args.set),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args), lr=args.learning_rate,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensorboard
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):
        loss_seg = 0
        loss_adv_target_value = 0
        loss_D_value = 0
        loss_cla_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # --- train G: freeze D so its grads don't accumulate ---
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source
        _, batch = trainloader_iter.__next__()
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        feature, prediction = model(images)
        prediction = interp(prediction)
        loss = seg_loss(prediction, labels)
        loss.backward()
        loss_seg = loss.item()

        # train with target: fool D into predicting "source" for target features
        _, batch = targetloader_iter.__next__()
        images, _, _ = batch
        images = images.to(device)

        feature_target, _ = model(images)
        _, D_out = model_D(feature_target)
        loss_adv_target = bce_loss(
            D_out,
            torch.FloatTensor(D_out.data.size()).fill_(source_label).to(device))
        loss = args.lambda_adv_target * loss_adv_target
        loss.backward()
        loss_adv_target_value = loss_adv_target.item()

        # --- train D: unfreeze its parameters again ---
        for param in model_D.parameters():
            param.requires_grad = True

        # train with source; also train D's auxiliary classifier head
        feature = feature.detach()
        cla, D_out = model_D(feature)
        cla = interp(cla)
        loss_cla = seg_loss(cla, labels)
        loss_D = bce_loss(
            D_out,
            torch.FloatTensor(D_out.data.size()).fill_(source_label).to(device))
        loss_D = loss_D / 2
        loss_Disc = args.lambda_s * loss_cla + loss_D
        loss_Disc.backward()
        loss_cla_value = loss_cla.item()
        loss_D_value = loss_D.item()

        # train with target
        feature_target = feature_target.detach()
        _, D_out = model_D(feature_target)
        loss_D = bce_loss(
            D_out,
            torch.FloatTensor(D_out.data.size()).fill_(target_label).to(device))
        loss_D = loss_D / 2
        loss_D.backward()
        loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg': loss_seg,
                'loss_cla': loss_cla_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_D': loss_D_value,
            }
            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f} loss_D = {4:.3f} loss_cla = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg, loss_adv_target_value,
                    loss_D_value, loss_cla_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir,
                                'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(model_D.state_dict(),
                       osp.join(args.snapshot_dir,
                                'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(model_D.state_dict(),
                       osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))

    if args.tensorboard:
        writer.close()
def main():
    """Create the model and start the training (optionally distributed; GTA5->BDD)."""
    global args
    args = get_arguments()

    if args.dist:
        init_dist(args.launcher, backend=args.backend)
    world_size = 1
    rank = 0
    if args.dist:
        rank = dist.get_rank()
        world_size = dist.get_world_size()

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'Deeplab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            # fixed: torch.load() takes no `strict` keyword (that belongs to
            # load_state_dict) — passing it raised a TypeError
            saved_state_dict = torch.load(args.restore_from)
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            i_parts = i.split('.')
            # skip the old classification head when transferring weights
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        model.load_state_dict(saved_state_dict, strict=False)
    elif args.model == 'DeeplabVGGBN':
        deeplab_vggbn.BatchNorm = SyncBatchNorm2d
        model = deeplab_vggbn.DeeplabVGGBN(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        model.load_state_dict(saved_state_dict, strict=False)
        del saved_state_dict

    model.train()
    model.to(device)
    if args.dist:
        broadcast_params(model)
    if rank == 0:
        print(model)
    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    # NOTE(review): model_D2 and its optimizer are created and zeroed but never
    # trained or stepped below — kept for interface/checkpoint parity.
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D1.train()
    model_D1.to(device)
    if args.dist:
        broadcast_params(model_D1)
    if args.restore_D is not None:
        D_dict = torch.load(args.restore_D)
        model_D1.load_state_dict(D_dict, strict=False)
        del D_dict
    model_D2.train()
    model_D2.to(device)
    if args.dist:
        broadcast_params(model_D2)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_data = GTA5BDDDataSet(args.data_dir, args.data_list,
                                max_iters=args.num_steps * args.iter_size * args.batch_size,
                                crop_size=input_size, scale=args.random_scale,
                                mirror=args.random_mirror, mean=IMG_MEAN)
    train_sampler = None
    if args.dist:
        train_sampler = DistributedSampler(train_data)
    trainloader = data.DataLoader(train_data, batch_size=args.batch_size,
                                  shuffle=False if train_sampler else True,
                                  num_workers=args.num_workers,
                                  pin_memory=False, sampler=train_sampler)
    trainloader_iter = enumerate(cycle(trainloader))

    target_data = BDDDataSet(args.data_dir_target, args.data_list_target,
                             max_iters=args.num_steps * args.iter_size * args.batch_size,
                             crop_size=input_size_target, scale=False,
                             mirror=args.random_mirror, mean=IMG_MEAN,
                             set=args.set)
    target_sampler = None
    if args.dist:
        target_sampler = DistributedSampler(target_data)
    targetloader = data.DataLoader(target_data, batch_size=args.batch_size,
                                   shuffle=False if target_sampler else True,
                                   num_workers=args.num_workers,
                                   pin_memory=False, sampler=target_sampler)
    targetloader_iter = enumerate(cycle(targetloader))

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args), lr=args.learning_rate,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()
    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # the source-side interp is rebuilt per batch from the label size below
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensorboard (rank 0 only)
    if args.tensorboard and rank == 0:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        writer = SummaryWriter(args.log_dir)

    torch.cuda.empty_cache()

    for i_iter in range(args.num_steps):
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):
            # --- train G: freeze D so its grads don't accumulate ---
            for param in model_D1.parameters():
                param.requires_grad = False
            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            _, batch = trainloader_iter.__next__()
            images, labels, size, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)
            interp = nn.Upsample(size=(size[1], size[0]), mode='bilinear',
                                 align_corners=True)

            pred1 = model(images)
            pred1 = interp(pred1)
            loss_seg1 = seg_loss(pred1, labels)
            loss = loss_seg1
            # proper normalization over sub-iterations and workers
            loss = loss / args.iter_size / world_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size

            # train with target
            _, batch = targetloader_iter.__next__()
            images, _, _ = batch
            images = images.to(device)

            pred_target1 = model(images)
            pred_target1 = interp_target(pred_target1)
            # fixed: F.softmax needs an explicit dim (the class axis)
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            loss_adv_target1 = bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(source_label).to(device))
            loss = args.lambda_adv_target1 * loss_adv_target1
            loss = loss / args.iter_size / world_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size

            # --- train D: unfreeze its parameters again ---
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred1 = pred1.detach()
            D_out1 = model_D1(F.softmax(pred1, dim=1))
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(source_label).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2 / world_size
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

            # train with target
            pred_target1 = pred_target1.detach()
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(target_label).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2 / world_size
            loss_D1.backward()

            if args.dist:
                average_gradients(model)
                average_gradients(model_D1)
                average_gradients(model_D2)

            loss_D_value1 += loss_D1.item()

        optimizer.step()
        optimizer_D1.step()

        if rank == 0:
            if args.tensorboard:
                scalar_info = {
                    'loss_seg1': loss_seg_value1,
                    'loss_seg2': loss_seg_value2,
                    'loss_adv_target1': loss_adv_target_value1,
                    'loss_adv_target2': loss_adv_target_value2,
                    'loss_D1': loss_D_value1 * world_size,
                    'loss_D2': loss_D_value2 * world_size,
                }
                if i_iter % 10 == 0:
                    for key, val in scalar_info.items():
                        writer.add_scalar(key, val, i_iter)

            print('exp = {}'.format(args.snapshot_dir))
            print(
                'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
                .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                        loss_adv_target_value1, loss_adv_target_value2,
                        loss_D_value1, loss_D_value2))

            if i_iter % args.save_pred_every == 0 and i_iter != 0 \
                    and i_iter < args.num_steps_stop - 1:
                print('taking snapshot ...')
                torch.save(model.state_dict(),
                           osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
                torch.save(model_D1.state_dict(),
                           osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
                print(args.snapshot_dir)

        if i_iter >= args.num_steps_stop - 1:
            # fixed: only rank 0 writes checkpoints, but every rank must leave
            # the loop — a rank-0-only break would hang the other workers
            if rank == 0:
                print('save model ...')
                torch.save(model.state_dict(),
                           osp.join(args.snapshot_dir,
                                    'GTA5_' + str(args.num_steps_stop) + '.pth'))
                torch.save(model_D1.state_dict(),
                           osp.join(args.snapshot_dir,
                                    'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            break

    if args.tensorboard and rank == 0:
        writer.close()
def main():
    """Create the model and start the training, with restart-from-snapshot support.

    When RESTART is set, hyperparameters are reloaded from the JSON dumped at
    the last snapshot and training resumes from `args.start_steps`.
    """
    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)

    args_dict = vars(args)
    import json

    # ###### load args for restart ######
    if RESTART:
        # fixed: the load path concatenated snapshot_dir and filename without a
        # separator while the save path below includes one; osp.join handles
        # both trailing-slash and no-slash snapshot dirs
        args_dict_file = osp.join(args.snapshot_dir,
                                  'args_dict_{}.json'.format(RESTART_ITER))
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        for arg in args_dict:
            args_dict[arg] = args_dict_last[arg]
    # ###### load args for restart ######

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    # #### restore model_D1, D2 and model
    if RESTART:
        # model parameters
        restart_from_model = args.restart_from + 'GTA5_{}.pth'.format(RESTART_ITER)
        saved_state_dict = torch.load(restart_from_model)
        model.load_state_dict(saved_state_dict)
        # model_D1 parameters
        restart_from_D1 = args.restart_from + 'GTA5_{}_D1.pth'.format(RESTART_ITER)
        saved_state_dict = torch.load(restart_from_D1)
        model_D1.load_state_dict(saved_state_dict)
        # model_D2 parameters
        restart_from_D2 = args.restart_from + 'GTA5_{}_D2.pth'.format(RESTART_ITER)
        saved_state_dict = torch.load(restart_from_D2)
        model_D2.load_state_dict(saved_state_dict)
    # #### model_D1, D2 are randomly initialized; model is ImageNet-pretrained
    else:
        # model parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # e.g. Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # skip the old classification head when transferring weights
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.to(device)
    model_D1.train()
    model_D1.to(device)
    model_D2.train()
    model_D2.to(device)

    # From here on the code is independent of model reloading; it only needs
    # the hyperparameters in args (lr, momentum, weight_decay, betas, n_iter).
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list,
                    max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size, scale=args.random_scale,
                    mirror=args.random_mirror, mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
        cityscapesDataSet(args.data_dir_target, args.data_list_target,
                          max_iters=args.num_steps * args.iter_size * args.batch_size,
                          crop_size=input_size_target, scale=False,
                          mirror=args.random_mirror, mean=IMG_MEAN,
                          set=args.set),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args), lr=args.learning_rate,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()
    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        # fixed: previously fell through and raised NameError at first use
        raise ValueError('unknown GAN mode: %s' % args.gan)
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensorboard
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.start_steps, args.num_steps):
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):
            # --- train G: freeze D so its grads don't accumulate ---
            for param in model_D1.parameters():
                param.requires_grad = False
            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            _, batch = trainloader_iter.__next__()
            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            # fixed: removed the pdb.set_trace() debugger breakpoints that were
            # left in the training loop — they halt every run
            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            # proper normalization over accumulated sub-iterations
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target
            _, batch = targetloader_iter.__next__()
            images, _, _ = batch
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # fixed: F.softmax needs an explicit dim (the class axis)
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_adv_target1 = bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(source_label).to(device))
            loss_adv_target2 = bce_loss(
                D_out2,
                torch.FloatTensor(D_out2.data.size()).fill_(source_label).to(device))
            loss = (args.lambda_adv_target1 * loss_adv_target1
                    + args.lambda_adv_target2 * loss_adv_target2)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # --- train D: unfreeze its parameters again ---
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detach so G receives no gradient)
            pred1 = pred1.detach()
            pred2 = pred2.detach()
            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(source_label).to(device))
            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(D_out2.data.size()).fill_(source_label).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2
            loss_D1.backward()
            loss_D2.backward()
            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(target_label).to(device))
            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(D_out2.data.size()).fill_(target_label).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2
            loss_D1.backward()
            loss_D2.backward()
            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        scalar_info = {
            'loss_seg1': loss_seg_value1,
            'loss_seg2': loss_seg_value2,
            'loss_adv_target1': loss_adv_target_value1,
            'loss_adv_target2': loss_adv_target_value2,
            'loss_D1': loss_D_value1,
            'loss_D2': loss_D_value2,
        }
        if i_iter % 10 == 0:
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir,
                                'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(model_D1.state_dict(),
                       osp.join(args.snapshot_dir,
                                'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(model_D2.state_dict(),
                       osp.join(args.snapshot_dir,
                                'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(model_D1.state_dict(),
                       osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(model_D2.state_dict(),
                       osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))

            # ###### also record the latest saved iteration #######
            args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
            args_dict['learning_rate_D'] = optimizer_D1.param_groups[0]['lr']
            args_dict['start_steps'] = i_iter
            args_dict_file = osp.join(args.snapshot_dir,
                                      'args_dict_{}.json'.format(i_iter))
            with open(args_dict_file, 'w') as f:
                json.dump(args_dict, f)
            # ###### also record the latest saved iteration #######

    writer.close()
def main():
    """Create the model and start adversarial training with stuff/instance prototypes.

    Trains a DeeplabMultiFeature generator against a single FCDiscriminator,
    while maintaining ring buffers of source-domain class ("stuff") and
    instance ("thing") features that regularize the target predictions via
    L1 distances to the nearest stored prototype.  Uses NVIDIA apex amp (O2)
    for mixed precision and periodically evaluates mIoU on Cityscapes val,
    keeping the best checkpoint.
    """
    device = torch.device("cuda" if not args.cpu else "cpu")
    cudnn.benchmark = True
    cudnn.enabled = True

    # (width, height) crop sizes for the source and target loaders.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)
    Iter = 0          # starting iteration (overwritten when resuming)
    bestIoU = 0       # best validation mIoU seen so far

    # Create network
    # init G
    if args.model == 'DeepLab':
        model = DeeplabMultiFeature(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        if args.continue_train:
            # Strip a leading 'module.' prefix left over from DataParallel
            # checkpoints before loading the full state dict.
            if list(saved_state_dict.keys())[0].split('.')[0] == 'module':
                for key in saved_state_dict.keys():
                    saved_state_dict['.'.join(
                        key.split('.')[1:])] = saved_state_dict.pop(key)
            model.load_state_dict(saved_state_dict)
        else:
            # Warm start from an ImageNet/COCO backbone: copy every parameter
            # except the classifier head ('layer5') when num_classes == 19.
            new_params = model.state_dict().copy()
            for i in saved_state_dict:
                i_parts = i.split('.')
                if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                    new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
            model.load_state_dict(new_params)

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes).to(device)
    if args.continue_train:
        # Derive the discriminator checkpoint path by inserting '_D' before
        # the file extension, then recover the iteration count from the
        # trailing '..._<iter>' portion of the filename.
        model_weights_path = args.restore_from
        temp = model_weights_path.split('.')
        temp[-2] = temp[-2] + '_D'
        model_D_weights_path = '.'.join(temp)
        model_D.load_state_dict(torch.load(model_D_weights_path))
        temp = model_weights_path.split('.')
        temp = temp[-2][-9:]
        # NOTE(review): assumes the filename tail looks like '..._<iter>' with
        # at most 9 trailing chars — verify against the snapshot naming below.
        Iter = int(temp.split('_')[1]) + 1

    model.train()
    model.to(device)
    model_D.train()
    model_D.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # init data loader
    # Source dataset is selected by the last component of the data directory.
    if args.data_dir.split('/')[-1] == 'gta5_deeplab':
        trainset = GTA5DataSet(args.data_dir, args.data_list,
                               max_iters=args.num_steps * args.iter_size * args.batch_size,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)
    elif args.data_dir.split('/')[-1] == 'syn_deeplab':
        trainset = synthiaDataSet(args.data_dir, args.data_list,
                                  max_iters=args.num_steps * args.iter_size * args.batch_size,
                                  crop_size=input_size,
                                  scale=args.random_scale,
                                  mirror=args.random_mirror,
                                  mean=IMG_MEAN)
    trainloader = data.DataLoader(trainset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target, args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False, mirror=args.random_mirror,
        mean=IMG_MEAN, set=args.set),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # init optimizer
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # apex amp O2 mixed precision; the model runs in fp16 internally, which
    # is why features are cast with .half() before model.layer6 below.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level="O2",
                                      keep_batchnorm_fp32=True,
                                      loss_scale="dynamic")
    model_D, optimizer_D = amp.initialize(model_D, optimizer_D,
                                          opt_level="O2",
                                          keep_batchnorm_fp32=True,
                                          loss_scale="dynamic")

    # init loss
    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)
    L1_loss = torch.nn.L1Loss(reduction='none')  # per-element; reduced manually

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear', align_corners=True)
    # Full Cityscapes resolution used only for validation.
    test_interp = nn.Upsample(size=(1024, 2048), mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # init prototype
    # Ring buffers of 2048-d pooled features: one bank per stuff class and a
    # larger (x10) bank per thing class; *_ptr counts total insertions.
    num_prototype = args.num_prototype
    num_ins = args.num_prototype * 10
    src_cls_features = torch.zeros([len(BG_LABEL), num_prototype, 2048],
                                   dtype=torch.float32).to(device)
    src_cls_ptr = np.zeros(len(BG_LABEL), dtype=np.uint64)
    src_ins_features = torch.zeros([len(FG_LABEL), num_ins, 2048],
                                   dtype=torch.float32).to(device)
    src_ins_ptr = np.zeros(len(FG_LABEL), dtype=np.uint64)

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        writer = SummaryWriter(args.log_dir)

    # start training
    for i_iter in range(Iter, args.num_steps):

        loss_seg_value = 0
        loss_adv_target_value = 0
        loss_D_value = 0
        # NOTE(review): loss_cls_value / loss_ins_value are initialized but
        # never updated; the print below reports loss_cls/loss_ins directly.
        loss_cls_value = 0
        loss_ins_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G
            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # train with source
            _, batch = trainloader_iter.__next__()
            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)
            src_feature, pred = model(images)
            pred_softmax = F.softmax(pred, dim=1)
            pred_idx = torch.argmax(pred_softmax, dim=1)

            # Keep only pixels where the prediction agrees with the (resized)
            # ground truth; everything else becomes ignore label 255 so the
            # prototypes are built from confidently-correct regions only.
            right_label = F.interpolate(labels.unsqueeze(0).float(),
                                        (pred_idx.size(1), pred_idx.size(2)),
                                        mode='nearest').squeeze(0).long()
            right_label[right_label != pred_idx] = 255

            # Update stuff-class prototype banks with masked-average features,
            # but only when the classifier head re-confirms the class.
            for ii in range(len(BG_LABEL)):
                cls_idx = BG_LABEL[ii]
                mask = right_label == cls_idx
                if torch.sum(mask) == 0:
                    continue
                feature = global_avg_pool(src_feature, mask.float())
                if cls_idx != torch.argmax(
                        torch.squeeze(model.layer6(
                            feature.half()).float())).item():
                    continue
                src_cls_features[ii,
                                 int(src_cls_ptr[ii] % num_prototype), :] = torch.squeeze(
                                     feature).clone().detach()
                src_cls_ptr[ii] += 1

            # Same for thing classes, per connected instance (largest 10 by
            # pixel count, as returned by seg_label).
            seg_ins = seg_label(right_label.squeeze())
            for ii in range(len(FG_LABEL)):
                cls_idx = FG_LABEL[ii]
                segmask, pixelnum = seg_ins[ii]
                if len(pixelnum) == 0:
                    continue
                sortmax = np.argsort(pixelnum)[::-1]
                for i in range(min(10, len(sortmax))):
                    mask = segmask == (sortmax[i] + 1)
                    feature = global_avg_pool(src_feature, mask.float())
                    if cls_idx != torch.argmax(
                            torch.squeeze(
                                model.layer6(feature.half()).float())).item():
                        continue
                    src_ins_features[ii,
                                     int(src_ins_ptr[ii] % num_ins), :] = torch.squeeze(
                                         feature).clone().detach()
                    src_ins_ptr[ii] += 1

            pred = interp(pred)
            loss_seg = seg_loss(pred, labels)
            loss = loss_seg

            # proper normalization
            loss = loss / args.iter_size
            amp_backward(loss, optimizer)
            loss_seg_value += loss_seg.item() / args.iter_size

            # train with target
            _, batch = targetloader_iter.__next__()
            images, _, _ = batch
            images = images.to(device)
            trg_feature, pred_target = model(images)
            pred_target_softmax = F.softmax(pred_target, dim=1)
            pred_target_idx = torch.argmax(pred_target_softmax, dim=1)

            # Prototype-matching losses: distance from each target feature to
            # its nearest stored source prototype (min over the bank).
            loss_cls = torch.zeros(1).to(device)
            loss_ins = torch.zeros(1).to(device)
            if i_iter > 0:
                for ii in range(len(BG_LABEL)):
                    cls_idx = BG_LABEL[ii]
                    # Skip until the bank has been filled at least once.
                    if src_cls_ptr[ii] / num_prototype <= 1:
                        continue
                    mask = pred_target_idx == cls_idx
                    feature = global_avg_pool(trg_feature, mask.float())
                    if cls_idx != torch.argmax(
                            torch.squeeze(
                                model.layer6(feature.half()).float())).item():
                        continue
                    ext_feature = feature.squeeze().expand(num_prototype, 2048)
                    loss_cls += torch.min(
                        torch.sum(L1_loss(ext_feature,
                                          src_cls_features[ii, :, :]),
                                  dim=1) / 2048.)

                seg_ins = seg_label(pred_target_idx.squeeze())
                for ii in range(len(FG_LABEL)):
                    cls_idx = FG_LABEL[ii]
                    if src_ins_ptr[ii] / num_ins <= 1:
                        continue
                    segmask, pixelnum = seg_ins[ii]
                    if len(pixelnum) == 0:
                        continue
                    sortmax = np.argsort(pixelnum)[::-1]
                    for i in range(min(10, len(sortmax))):
                        mask = segmask == (sortmax[i] + 1)
                        feature = global_avg_pool(trg_feature, mask.float())
                        feature = feature.squeeze().expand(num_ins, 2048)
                        # Averaged over the number of instances considered.
                        loss_ins += torch.min(
                            torch.sum(L1_loss(feature,
                                              src_ins_features[ii, :, :]),
                                      dim=1) / 2048.) / min(10, len(sortmax))

            # Adversarial term: make D label the target prediction as source.
            pred_target = interp_target(pred_target)
            D_out = model_D(F.softmax(pred_target, dim=1))
            loss_adv_target = bce_loss(
                D_out,
                torch.FloatTensor(
                    D_out.data.size()).fill_(source_label).to(device))
            loss = args.lambda_adv_target * loss_adv_target + args.lambda_adv_cls * loss_cls + args.lambda_adv_ins * loss_ins
            loss = loss / args.iter_size
            amp_backward(loss, optimizer)
            loss_adv_target_value += loss_adv_target.item() / args.iter_size

            # train D
            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with source
            pred = pred.detach()
            D_out = model_D(F.softmax(pred, dim=1))
            loss_D = bce_loss(
                D_out,
                torch.FloatTensor(
                    D_out.data.size()).fill_(source_label).to(device))
            loss_D = loss_D / args.iter_size / 2
            amp_backward(loss_D, optimizer_D)
            loss_D_value += loss_D.item()

            # train with target
            pred_target = pred_target.detach()
            D_out = model_D(F.softmax(pred_target, dim=1))
            loss_D = bce_loss(
                D_out,
                torch.FloatTensor(
                    D_out.data.size()).fill_(target_label).to(device))
            loss_D = loss_D / args.iter_size / 2
            amp_backward(loss_D, optimizer_D)
            loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg': loss_seg_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_D': loss_D_value,
            }
            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv = {3:.3f} loss_D = {4:.3f} loss_cls = {5:.3f} loss_ins = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_target_value, loss_D_value, loss_cls.item(),
                    loss_ins.item()))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            if not os.path.exists(args.save):
                os.makedirs(args.save)
            # Run full-resolution inference over Cityscapes val and score it;
            # keep the checkpoint if it beats the best mIoU so far.
            testloader = data.DataLoader(cityscapesDataSet(
                args.data_dir_target, args.data_list_target_test,
                crop_size=(1024, 512), mean=IMG_MEAN,
                scale=False, mirror=False, set='val'),
                batch_size=1, shuffle=False, pin_memory=True)
            model.eval()
            for index, batch in enumerate(testloader):
                image, _, name = batch
                with torch.no_grad():
                    output1, output2 = model(Variable(image).to(device))
                output = test_interp(output2).cpu().data[0].numpy()
                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
                output = Image.fromarray(output)
                name = name[0].split('/')[-1]
                output.save('%s/%s' % (args.save, name))
            mIoUs = compute_mIoU(osp.join(args.data_dir_target, 'gtFine/val'),
                                 args.save, 'dataset/cityscapes_list')
            mIoU = round(np.nanmean(mIoUs) * 100, 2)
            if mIoU > bestIoU:
                bestIoU = mIoU
                torch.save(model.state_dict(),
                           osp.join(args.snapshot_dir, 'BestGTA5.pth'))
                torch.save(model_D.state_dict(),
                           osp.join(args.snapshot_dir, 'BestGTA5_D.pth'))
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))
            model.train()

    if args.tensorboard:
        writer.close()
def main():
    """Create the model and start the training.

    Few-shot segmentation training with a single adversarial discriminator:
    the generator predicts the anchor (query) mask conditioned on a positive
    support image/mask, while FCDiscriminator is trained to tell generated
    masks (label 0) from one-hot ground-truth masks (label 1).
    """
    device = torch.device("cuda" if not args.cpu else "cpu")

    # (width, height) pairs parsed from the CLI strings; note input_size and
    # input_size_target are parsed but the interp below is hard-coded 416x416.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        #model = DeeplabMulti(num_classes=args.num_classes)
        #model = Res_Deeplab(num_classes=args.num_classes)
        model = DeepLab(backbone='resnet', output_stride=16)
        '''
        if args.restore_from[:4] == 'http' :
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        #restore(model, saved_state_dict)
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not i_parts[0] == 'layer4' and not i_parts[0] == 'fc':
                #new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                new_params[i] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)
        '''
    else:
        raise NotImplementedError

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    # if args.restore_from_D[:4] == 'http':
    #     saved_state_dict = model_zoo.load_url(args.restore_from_D)
    # else:
    #     saved_state_dict = torch.load(args.restore_from_D)
    # ### for running different versions of pytorch
    # model_dict = model_D1.state_dict()
    # saved_state_dict = {k: v for k, v in saved_state_dict.items() if k in model_dict}
    # model_dict.update(saved_state_dict)
    # model_D1.load_state_dict(saved_state_dict)
    model_D1.train()
    model_D1.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_loader = data_loader(args)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    # GAN objective: Vanilla uses BCE on logits, LS uses least-squares (MSE).
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss()

    interp = nn.Upsample(size=(416, 416), mode='bilinear', align_corners=True)
    #interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear', align_corners=True)

    # labels for adversarial training

    # set up tensorboard
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        writer = SummaryWriter(args.log_dir)

    count = args.start_count  # iteration counter
    for dat in train_loader:
        if count > args.num_steps:
            break

        loss_seg_value1_anchor = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, count)
        optimizer_D1.zero_grad()
        adjust_learning_rate_D(optimizer_D1, count)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            # Few-shot episode sampling: e.g. with group=0 the training split
            # has 15 classes and validation has 5.  Two training classes are
            # sampled; from one class two images are drawn (the anchor/query
            # and a positive support, same class), and from the other class
            # one image is drawn as the negative (different class).
            #############################
            anchor_img, anchor_mask, pos_img, pos_mask, neg_img, neg_mask = dat
            # Returned: anchor image + mask, positive image + mask (same class
            # as the anchor), negative image + mask (different class).
            # NOTE(review): neg_img/neg_mask are unpacked but never used here.
            anchor_img, anchor_mask, pos_img, pos_mask, \
                = anchor_img.cuda(), anchor_mask.cuda(), pos_img.cuda(), pos_mask.cuda()
            # [1, 3, 386, 500],[1, 386, 500],[1, 3, 374, 500],[1, 374, 500]

            anchor_mask = torch.unsqueeze(anchor_mask, dim=1)  # [1, 1, 386, 500]
            pos_mask = torch.unsqueeze(pos_mask, dim=1)  # [1,1, 374, 500]

            # Positive support and anchor are stacked into one batch; the
            # model is conditioned on the support mask.
            samples = torch.cat([pos_img, anchor_img], 0)

            pred = model(samples, pos_mask)  ##[2, 2, 53, 53],#[2, 2, 53, 53]
            pred = interp(pred)

            loss_seg1_anchor = seg_loss(
                pred, anchor_mask.squeeze().unsqueeze(0).long())

            # Adversarial term: push D to score the generated mask as "real"
            # (label 1 here plays the role of the ground-truth domain).
            D_out1 = model_D1(F.softmax(pred))
            loss_adv_target1 = bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(1).to(
                    device))
            '''
            s = torch.stack([s, 1-s])
            loss_s = seg_loss()
            '''
            loss = loss_seg1_anchor + args.lambda_adv_target1 * loss_adv_target1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1_anchor += loss_seg1_anchor.item() / args.iter_size
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size

            # train D
            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            # train with anchor
            pred_target1 = pred.detach()
            D_out1 = model_D1(F.softmax(pred_target1))
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(0).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

            # train with GT
            anchor_gt = Variable(one_hot(anchor_mask)).cuda()
            D_out1 = model_D1(anchor_gt)
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(1).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

        optimizer.step()
        optimizer_D1.step()
        count = count + 1

        if args.tensorboard:
            scalar_info = {
                'loss_seg1_anchor': loss_seg_value1_anchor,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_D1': loss_D_value1,
            }
            if count % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, count)

        # print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_adv1 = {3:.3f}, loss_D1 = {4:.3f}'
            .format(count, args.num_steps, loss_seg_value1_anchor,
                    loss_adv_target_value1, loss_D_value1))

        if count >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'voc2012_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'voc2012_' + str(args.num_steps_stop) + '_D1.pth'))
            break

        if count % args.save_pred_every == 0 and count != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'voc2012_' + str(count) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'voc2012_' + str(count) + '_D1.pth'))

    if args.tensorboard:
        writer.close()
def main():
    """Create the model and start the training.

    GTA5->Cityscapes adaptation with one output-space discriminator (model_D2)
    plus one per-class discriminator (FCDiscriminatorCLS each) that is enabled
    after cls_begin_iter iterations.  Target pseudo-labels and their softmax
    confidences are cached per image and used to derive per-class confidence
    thresholds every 5000 iterations.  Mixed precision via apex amp (O1).
    """
    device = torch.device("cuda" if not args.cpu else "cpu")

    # (width, height) crop sizes for source and target loaders.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMultiFeature(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        # Warm start: copy everything except the classifier head ('layer5')
        # when training with the standard 19 classes.
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2.train()
    model_D2.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(
        args.data_dir, args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    # Kept as a named dataset (not inlined) because cityset.files is used
    # below to map image names to cache indices.
    cityset = cityscapesDataSet(args.data_dir_target, args.data_list_target,
                                max_iters=args.num_steps * args.iter_size * args.batch_size,
                                crop_size=input_size_target,
                                scale=False,
                                mirror=args.random_mirror,
                                mean=IMG_MEAN,
                                set=args.set)
    targetloader = data.DataLoader(cityset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # init cls D
    # One discriminator + optimizer per class, each wrapped in apex amp O1.
    model_clsD = []
    optimizer_clsD = []
    for i in range(args.num_classes):
        model_temp = FCDiscriminatorCLS(
            num_classes=args.num_classes).to(device).train()
        optimizer_temp = optim.Adam(model_temp.parameters(),
                                    lr=args.learning_rate_D,
                                    betas=(0.9, 0.99))
        optimizer_temp.zero_grad()
        #model_temp, optimizer_temp = amp.initialize(
        #    model_temp, optimizer_temp, opt_level="O1",
        #    keep_batchnorm_fp32=None, loss_scale="dynamic"
        #)
        model_temp, optimizer_temp = amp.initialize(model_temp,
                                                    optimizer_temp,
                                                    opt_level="O1",
                                                    keep_batchnorm_fp32=None,
                                                    loss_scale="dynamic")
        model_clsD.append(model_temp)
        optimizer_clsD.append(optimizer_temp)

    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level="O1",
                                      keep_batchnorm_fp32=None,
                                      loss_scale="dynamic")
    model_D2, optimizer_D2 = amp.initialize(model_D2, optimizer_D2,
                                            opt_level="O1",
                                            keep_batchnorm_fp32=None,
                                            loss_scale="dynamic")

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # Class-wise discriminators kick in after this many iterations.
    cls_begin_iter = 10000
    # 2975 = size of the Cityscapes train split; one cache slot per image.
    num_target_imgs = 2975
    predicted_label = np.zeros(
        (num_target_imgs, 1, input_size_target[1], input_size_target[0]),
        dtype=np.uint8)
    predicted_prob = np.zeros(
        (num_target_imgs, 1, input_size_target[1], input_size_target[0]),
        dtype=np.float16)
    name2idxmap = {}
    for i in range(num_target_imgs):
        name2idxmap[cityset.files[i]['name']] = i
    thres = []  # per-class confidence thresholds, filled every 5000 iters

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0
        loss_cls_adv = 0
        loss_cls_adv_value = 0
        loss_cls_D = 0
        loss_cls_D_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D2, i_iter)
        if i_iter >= cls_begin_iter:
            for i in range(args.num_classes):
                optimizer_clsD[i].zero_grad()
                adjust_learning_rate_D(optimizer_clsD[i], i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D2.parameters():
                param.requires_grad = False
            if i_iter >= cls_begin_iter:
                for i in range(args.num_classes):
                    for param in model_clsD[i].parameters():
                        param.requires_grad = False

            # train with source
            _, batch = trainloader_iter.__next__()

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            _, pred2 = model(images)
            pred2 = interp(pred2)

            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2

            # proper normalization
            loss = loss / args.iter_size
            amp_backward(loss, optimizer)
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target
            _, batch = targetloader_iter.__next__()
            images, _, name = batch
            images = images.to(device)
            name = name[0]
            img_idx = name2idxmap[name]

            _, pred_target2 = model(images)
            pred_target2 = interp_target(pred_target2)
            pred_target_score = F.softmax(pred_target2, dim=1)

            D_out2 = model_D2(pred_target_score)
            loss_adv_target2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label).to(device))

            # Cache this image's argmax pseudo-label and its max confidence
            # for the periodic threshold computation below.
            target_pred_prob, target_pred_cls = torch.max(pred_target_score,
                                                          dim=1)
            predicted_label[img_idx,
                            ...] = target_pred_cls.cpu().data.numpy().astype(
                                np.uint8)
            predicted_prob[img_idx,
                           ...] = target_pred_prob.cpu().data.numpy().astype(
                               np.float16)

            # Every 5000 iterations (starting exactly at cls_begin_iter),
            # recompute per-class median confidences, capped at 0.9.
            if i_iter >= cls_begin_iter and i_iter % 5000 == 0:
                thres = []
                for i in range(args.num_classes):
                    x = predicted_prob[predicted_label == i]
                    if len(x) == 0:
                        thres.append(0)
                        continue
                    x = np.sort(x)
                    # NOTE(review): np.int is deprecated/removed in modern
                    # NumPy — needs int(...) when upgrading.
                    thres.append(x[np.int(np.round(len(x) * 0.5))])
                print(thres)
                thres = np.array(thres)
                thres[thres > 0.9] = 0.9
                np.save(osp.join(args.snapshot_dir, 'predicted_label'),
                        predicted_label)
                np.save(osp.join(args.snapshot_dir, 'predicted_prob'),
                        predicted_prob)

            if i_iter >= cls_begin_iter:
                target_pred_cls = target_pred_cls.long().detach()
                for i in range(args.num_classes):
                    # NOTE(review): this compares class *indices* against the
                    # probability threshold thres[i]; it likely was meant to be
                    # target_pred_prob >= thres[i] — confirm intent.
                    cls_mask = (target_pred_cls == i) * (target_pred_cls >=
                                                         thres[i])
                    if torch.sum(cls_mask) == 0:
                        continue
                    # Pixels of class i become "source" targets for the class
                    # discriminator; everything else is ignored (255).
                    cls_gt = torch.tensor(
                        target_pred_cls.data).long().to(device)
                    cls_gt[~cls_mask] = 255
                    cls_gt[cls_mask] = source_label
                    cls_out = model_clsD[i](pred_target_score)
                    loss_cls_adv += seg_loss(cls_out, cls_gt)
                loss_cls_adv_value = loss_cls_adv.item() / args.iter_size

            loss = args.lambda_adv_target2 * loss_adv_target2 + LAMBDA_CLS_ADV * loss_cls_adv
            loss = loss / args.iter_size
            amp_backward(loss, optimizer)
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred2 = pred2.detach()

            D_out2 = model_D2(F.softmax(pred2, dim=1))
            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label).to(device))
            loss_D2 = loss_D2 / args.iter_size / 2
            amp_backward(loss_D2, optimizer_D2)
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target2 = pred_target2.detach()

            D_out2 = model_D2(F.softmax(pred_target2, dim=1))
            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(target_label).to(device))
            loss_D2 = loss_D2 / args.iter_size / 2
            amp_backward(loss_D2, optimizer_D2)
            loss_D_value2 += loss_D2.item()

            if i_iter >= cls_begin_iter:
                for i in range(args.num_classes):
                    for param in model_clsD[i].parameters():
                        param.requires_grad = True

                # Class discriminators: source pixels (agreeing with GT) get
                # label source_label ...
                pred_source_score = F.softmax(pred2, dim=1)
                source_pred_prob, source_pred_cls = torch.max(
                    pred_source_score, dim=1)
                source_pred_cls = source_pred_cls.long().detach()
                for i in range(args.num_classes):
                    cls_mask = (source_pred_cls == i) * (labels == i)
                    if torch.sum(cls_mask) == 0:
                        continue
                    cls_gt = torch.tensor(
                        source_pred_cls.data).long().to(device)
                    cls_gt[~cls_mask] = 255
                    cls_gt[cls_mask] = source_label
                    cls_out = model_clsD[i](pred_source_score)
                    loss_cls_D = seg_loss(cls_out, cls_gt) / 2
                    amp_backward(loss_cls_D, optimizer_clsD[i])
                    loss_cls_D_value += loss_cls_D.item()

                # ... and confident target pixels get label target_label.
                pred_target_score = F.softmax(pred_target2, dim=1)
                target_pred_prob, target_pred_cls = torch.max(
                    pred_target_score, dim=1)
                target_pred_cls = target_pred_cls.long().detach()
                for i in range(args.num_classes):
                    # NOTE(review): same class-index-vs-threshold comparison
                    # as in the G phase above — confirm intent.
                    cls_mask = (target_pred_cls == i) * (target_pred_cls >=
                                                         thres[i])
                    if torch.sum(cls_mask) == 0:
                        continue
                    cls_gt = torch.tensor(
                        target_pred_cls.data).long().to(device)
                    cls_gt[~cls_mask] = 255
                    cls_gt[cls_mask] = target_label
                    cls_out = model_clsD[i](pred_target_score)
                    # NOTE(review): loss_cls_adv accumulated here (D phase)
                    # after the G backward already ran — looks unintended.
                    loss_cls_adv += seg_loss(cls_out, cls_gt)
                    loss_cls_D = seg_loss(cls_out, cls_gt) / 2
                    amp_backward(loss_cls_D, optimizer_clsD[i])
                    loss_cls_D_value += loss_cls_D.item()

        optimizer.step()
        optimizer_D2.step()
        if i_iter >= cls_begin_iter:
            for i in range(args.num_classes):
                optimizer_clsD[i].step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg2': loss_seg_value2,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D2': loss_D_value2,
            }
            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg2 = {2:.3f}, loss_adv2 = {3:.3f} loss_D2 = {4:.3f} loss_cls_adv = {5:.3f} loss_cls_D = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value2,
                    loss_adv_target_value2, loss_D_value2, loss_cls_adv_value,
                    loss_cls_D_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            for i in range(args.num_classes):
                # NOTE(review): filename has no class index, so every clsD
                # overwrites the same file (unlike the snapshot branch below).
                torch.save(
                    model_clsD[i].state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(NUM_STEPS) + '_clsD.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(i_iter) + '_D2.pth'))
            for i in range(args.num_classes):
                torch.save(
                    model_clsD[i].state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_clsD' + str(i) + '.pth'))

    if args.tensorboard:
        writer.close()
net.load_state_dict(torch.load(args.load, map_location=device)) logging.info(f'Model loaded from {args.load}') net.to(device=device) # faster convolutions, but more memory # cudnn.benchmark = True from torchsummary import summary summary(net, (config.NUM_CHANNELS, config.CROP_H, config.CROP_W)) ####################### # Discriminator ####################### discriminator = FCDiscriminator(num_classes=config.NUM_CLASSES) discriminator.to(device=device) logging.info(f'Discriminator:\n' f'\t{config.NUM_CLASSES} input channels (classes)\n') from torchsummary import summary summary(discriminator, (config.NUM_CLASSES, config.CROP_H, config.CROP_W)) ####################### ####################### try: train_net(net=net, discriminator=discriminator, upsample=upsample, epochs=args.epochs,
def train(log_file, arch, dataset, batch_size, iter_size, num_workers,
          partial_data, partial_data_size, partial_id, ignore_label,
          crop_size, eval_crop_size, is_training, learning_rate,
          learning_rate_d, supervised, lambda_adv_pred, lambda_semi,
          lambda_semi_adv, mask_t, semi_start, semi_start_adv, d_remain,
          momentum, not_restore_last, num_steps, power, random_mirror,
          random_scale, random_seed, restore_from, restore_from_d,
          eval_every, save_snapshot_every, snapshot_dir, weight_decay,
          device):
    """Adversarial semi-supervised semantic segmentation training loop.

    Trains a segmentation network against an FCDiscriminator on PASCAL VOC,
    optionally using only a subset of labelled data (``partial_data`` /
    ``partial_data_size``) and a semi-supervised adversarial loss on the
    remaining unlabelled images.  Periodically evaluates mean IoU and saves
    snapshots of both networks to ``snapshot_dir``.
    """
    # Must be the first statement: snapshots the argument values for logging.
    settings = locals().copy()

    import cv2
    import torch
    import torch.nn as nn
    from torch.utils import data, model_zoo
    import numpy as np
    import pickle
    import torch.optim as optim
    import torch.nn.functional as F
    import scipy.misc
    import sys
    import os
    import os.path as osp
    from model.deeplab import Res_Deeplab
    from model.unet import unet_resnet50
    from model.deeplabv3 import resnet101_deeplabv3
    from model.discriminator import FCDiscriminator
    from utils.loss import CrossEntropy2d, BCEWithLogitsLoss2d
    from utils.evaluation import EvaluatorIoU
    from dataset.voc_dataset import VOCDataSet
    import logger
    torch_device = torch.device(device)
    import time

    # Refuse to clobber an existing log file.
    if log_file != '' and log_file != 'none':
        if os.path.exists(log_file):
            print('Log file {} already exists; exiting...'.format(log_file))
            return

    with logger.LogFile(log_file if log_file != 'none' else None):
        if dataset == 'pascal_aug':
            ds = VOCDataSet(augmented_pascal=True)
        elif dataset == 'pascal':
            ds = VOCDataSet(augmented_pascal=False)
        else:
            print('Dataset {} not yet supported'.format(dataset))
            return

        print('Command: {}'.format(sys.argv[0]))
        print('Arguments: {}'.format(' '.join(sys.argv[1:])))
        print('Settings: {}'.format(', '.join([
            '{}={}'.format(k, settings[k])
            for k in sorted(list(settings.keys()))
        ])))
        print('Loaded data')

        def loss_calc(pred, label):
            """Cross-entropy loss for semantic segmentation.

            pred: (batch, channels, h, w) logits; label: per-pixel class map.
            """
            label = label.long().to(torch_device)
            criterion = CrossEntropy2d()
            return criterion(pred, label)

        def lr_poly(base_lr, step, max_iter, power):
            # Polynomial learning-rate decay.
            return base_lr * ((1 - float(step) / max_iter)**(power))

        def adjust_learning_rate(optimizer, i_iter):
            # Poly decay for the segmentation net; a second param group
            # (if present) runs at 10x, per the DeepLab convention.
            lr = lr_poly(learning_rate, i_iter, num_steps, power)
            optimizer.param_groups[0]['lr'] = lr
            if len(optimizer.param_groups) > 1:
                optimizer.param_groups[1]['lr'] = lr * 10

        def adjust_learning_rate_D(optimizer, i_iter):
            # Same schedule, but starting from the discriminator base LR.
            lr = lr_poly(learning_rate_d, i_iter, num_steps, power)
            optimizer.param_groups[0]['lr'] = lr
            if len(optimizer.param_groups) > 1:
                optimizer.param_groups[1]['lr'] = lr * 10

        def one_hot(label):
            # Convert (batch, h, w) label maps to a one-hot tensor that the
            # discriminator can consume like a softmax output.
            label = label.numpy()
            one_hot = np.zeros((label.shape[0], ds.num_classes,
                                label.shape[1], label.shape[2]),
                               dtype=label.dtype)
            for i in range(ds.num_classes):
                one_hot[:, i, ...] = (label == i)
            # handle ignore labels
            return torch.tensor(one_hot, dtype=torch.float,
                                device=torch_device)

        def make_D_label(label, ignore_mask):
            # Constant-valued domain label map; ignored pixels are set to
            # ignore_label so BCEWithLogitsLoss2d skips them.
            ignore_mask = np.expand_dims(ignore_mask, axis=1)
            D_label = np.ones(ignore_mask.shape) * label
            D_label[ignore_mask] = ignore_label
            D_label = torch.tensor(D_label, dtype=torch.float,
                                   device=torch_device)
            return D_label

        h, w = map(int, eval_crop_size.split(','))
        eval_crop_size = (h, w)
        h, w = map(int, crop_size.split(','))
        crop_size = (h, w)

        # create network
        if arch == 'deeplab2':
            model = Res_Deeplab(num_classes=ds.num_classes)
        elif arch == 'unet_resnet50':
            model = unet_resnet50(num_classes=ds.num_classes)
        elif arch == 'resnet101_deeplabv3':
            model = resnet101_deeplabv3(num_classes=ds.num_classes)
        else:
            print('Architecture {} not supported'.format(arch))
            return

        # load pretrained parameters
        if restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(restore_from)
        else:
            saved_state_dict = torch.load(restore_from)

        # only copy the params that exist in current model (caffe-like)
        new_params = model.state_dict().copy()
        for name, param in new_params.items():
            if name in saved_state_dict and param.size(
            ) == saved_state_dict[name].size():
                new_params[name].copy_(saved_state_dict[name])
        model.load_state_dict(new_params)
        model.train()
        model = model.to(torch_device)

        # init D
        model_D = FCDiscriminator(num_classes=ds.num_classes)
        if restore_from_d is not None:
            model_D.load_state_dict(torch.load(restore_from_d))
        model_D.train()
        model_D = model_D.to(torch_device)
        print('Built model')

        if snapshot_dir is not None:
            if not os.path.exists(snapshot_dir):
                os.makedirs(snapshot_dir)

        ds_train_xy = ds.train_xy(crop_size=crop_size, scale=random_scale,
                                  mirror=random_mirror,
                                  range01=model.RANGE01, mean=model.MEAN,
                                  std=model.STD)
        ds_train_y = ds.train_y(crop_size=crop_size, scale=random_scale,
                                mirror=random_mirror, range01=model.RANGE01,
                                mean=model.MEAN, std=model.STD)
        ds_val_xy = ds.val_xy(crop_size=eval_crop_size, scale=False,
                              mirror=False, range01=model.RANGE01,
                              mean=model.MEAN, std=model.STD)

        train_dataset_size = len(ds_train_xy)
        if partial_data_size != -1:
            # BUGFIX: the original compared partial_data_size to itself,
            # which is never true; the intent is a bounds check against the
            # dataset size.
            if partial_data_size > train_dataset_size:
                print('partial-data-size > |train|: exiting')
                return

        if partial_data == 1.0 and (partial_data_size == -1
                                    or partial_data_size
                                    == train_dataset_size):
            # Fully supervised over the whole training set.
            trainloader = data.DataLoader(ds_train_xy,
                                          batch_size=batch_size,
                                          shuffle=True, num_workers=5,
                                          pin_memory=True)
            trainloader_gt = data.DataLoader(ds_train_y,
                                             batch_size=batch_size,
                                             shuffle=True, num_workers=5,
                                             pin_memory=True)
            trainloader_remain = None
            print('|train|={}'.format(train_dataset_size))
            print('|val|={}'.format(len(ds_val_xy)))
        else:
            # sample partial data
            if partial_data_size != -1:
                partial_size = partial_data_size
            else:
                partial_size = int(partial_data * train_dataset_size)

            if partial_id is not None:
                # BUGFIX: pickle files must be opened in binary mode.
                train_ids = pickle.load(open(partial_id, 'rb'))
                print('loading train ids from {}'.format(partial_id))
            else:
                rng = np.random.RandomState(random_seed)
                train_ids = list(rng.permutation(train_dataset_size))

            if snapshot_dir is not None:
                pickle.dump(
                    train_ids,
                    open(osp.join(snapshot_dir, 'train_id.pkl'), 'wb'))

            print('|train supervised|={}'.format(partial_size))
            print('|train unsupervised|={}'.format(train_dataset_size -
                                                   partial_size))
            print('|val|={}'.format(len(ds_val_xy)))
            print('supervised={}'.format(list(train_ids[:partial_size])))

            train_sampler = data.sampler.SubsetRandomSampler(
                train_ids[:partial_size])
            train_remain_sampler = data.sampler.SubsetRandomSampler(
                train_ids[partial_size:])
            train_gt_sampler = data.sampler.SubsetRandomSampler(
                train_ids[:partial_size])

            trainloader = data.DataLoader(ds_train_xy,
                                          batch_size=batch_size,
                                          sampler=train_sampler,
                                          num_workers=3, pin_memory=True)
            trainloader_remain = data.DataLoader(
                ds_train_xy, batch_size=batch_size,
                sampler=train_remain_sampler, num_workers=3,
                pin_memory=True)
            trainloader_gt = data.DataLoader(ds_train_y,
                                             batch_size=batch_size,
                                             sampler=train_gt_sampler,
                                             num_workers=3,
                                             pin_memory=True)
            trainloader_remain_iter = enumerate(trainloader_remain)

        testloader = data.DataLoader(ds_val_xy, batch_size=1, shuffle=False,
                                     pin_memory=True)
        print('Data loaders ready')

        trainloader_iter = enumerate(trainloader)
        trainloader_gt_iter = enumerate(trainloader_gt)

        # implement model.optim_parameters(args) to handle different models'
        # lr setting
        # optimizer for segmentation network
        optimizer = optim.SGD(model.optim_parameters(learning_rate),
                              lr=learning_rate, momentum=momentum,
                              weight_decay=weight_decay)
        optimizer.zero_grad()

        # optimizer for discriminator network
        optimizer_D = optim.Adam(model_D.parameters(), lr=learning_rate_d,
                                 betas=(0.9, 0.99))
        optimizer_D.zero_grad()

        # loss/ bilinear upsampling
        bce_loss = BCEWithLogitsLoss2d()
        print('Built optimizer')

        # labels for adversarial training
        pred_label = 0
        gt_label = 1

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_mask_accum = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        t1 = time.time()
        print('Training for {} steps...'.format(num_steps))
        for i_iter in range(num_steps + 1):
            model.train()
            model.freeze_batchnorm()

            optimizer.zero_grad()
            adjust_learning_rate(optimizer, i_iter)
            optimizer_D.zero_grad()
            adjust_learning_rate_D(optimizer_D, i_iter)

            for sub_i in range(iter_size):
                # train G
                if not supervised:
                    # don't accumulate grads in D
                    for param in model_D.parameters():
                        param.requires_grad = False

                # do semi first
                if not supervised and (lambda_semi > 0 or lambda_semi_adv > 0) \
                        and i_iter >= semi_start_adv and \
                        trainloader_remain is not None:
                    try:
                        _, batch = next(trainloader_remain_iter)
                    except StopIteration:
                        # Re-prime the unlabelled-data iterator when
                        # exhausted (was a bare except).
                        trainloader_remain_iter = enumerate(
                            trainloader_remain)
                        _, batch = next(trainloader_remain_iter)

                    # only access to img
                    images, _, _, _ = batch
                    images = images.float().to(torch_device)

                    pred = model(images)
                    pred_remain = pred.detach()

                    D_out = model_D(F.softmax(pred, dim=1))
                    # torch.sigmoid replaces the deprecated F.sigmoid.
                    D_out_sigmoid = torch.sigmoid(
                        D_out).data.cpu().numpy().squeeze(axis=1)

                    # np.bool alias was removed from modern NumPy.
                    ignore_mask_remain = np.zeros(
                        D_out_sigmoid.shape).astype(bool)
                    loss_semi_adv = lambda_semi_adv * bce_loss(
                        D_out, make_D_label(gt_label, ignore_mask_remain))
                    loss_semi_adv = loss_semi_adv / iter_size
                    loss_semi_adv_value += float(
                        loss_semi_adv) / lambda_semi_adv

                    if lambda_semi <= 0 or i_iter < semi_start:
                        loss_semi_adv.backward()
                        loss_semi_value = 0
                    else:
                        # produce ignore mask: keep only pixels the
                        # discriminator rates as sufficiently "real".
                        semi_ignore_mask = (D_out_sigmoid < mask_t)
                        semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                        semi_gt[semi_ignore_mask] = ignore_label

                        semi_ratio = 1.0 - float(
                            semi_ignore_mask.sum()) / semi_ignore_mask.size
                        loss_semi_mask_accum += float(semi_ratio)

                        if semi_ratio == 0.0:
                            loss_semi_value += 0
                        else:
                            semi_gt = torch.FloatTensor(semi_gt)
                            loss_semi = lambda_semi * loss_calc(pred, semi_gt)
                            loss_semi = loss_semi / iter_size
                            loss_semi_value += float(loss_semi) / lambda_semi
                            loss_semi += loss_semi_adv
                            loss_semi.backward()
                else:
                    loss_semi = None
                    loss_semi_adv = None

                # train with source
                try:
                    _, batch = next(trainloader_iter)
                except StopIteration:
                    trainloader_iter = enumerate(trainloader)
                    _, batch = next(trainloader_iter)

                images, labels, _, _ = batch
                images = images.float().to(torch_device)
                ignore_mask = (labels.numpy() == ignore_label)
                pred = model(images)

                loss_seg = loss_calc(pred, labels)
                if supervised:
                    loss = loss_seg
                else:
                    D_out = model_D(F.softmax(pred, dim=1))
                    loss_adv_pred = bce_loss(
                        D_out, make_D_label(gt_label, ignore_mask))
                    loss = loss_seg + lambda_adv_pred * loss_adv_pred
                    loss_adv_pred_value += float(loss_adv_pred) / iter_size

                # proper normalization
                loss = loss / iter_size
                loss.backward()
                loss_seg_value += float(loss_seg) / iter_size

                if not supervised:
                    # train D
                    # bring back requires_grad
                    for param in model_D.parameters():
                        param.requires_grad = True

                    # train with pred
                    pred = pred.detach()
                    if d_remain:
                        # NOTE(review): assumes the semi branch ran this
                        # step so pred_remain/ignore_mask_remain exist —
                        # confirm d_remain is only set with semi enabled.
                        pred = torch.cat((pred, pred_remain), 0)
                        ignore_mask = np.concatenate(
                            (ignore_mask, ignore_mask_remain), axis=0)

                    D_out = model_D(F.softmax(pred, dim=1))
                    loss_D = bce_loss(D_out,
                                      make_D_label(pred_label, ignore_mask))
                    loss_D = loss_D / iter_size / 2
                    loss_D.backward()
                    loss_D_value += float(loss_D)

                    # train with gt
                    # get gt labels
                    try:
                        _, batch = next(trainloader_gt_iter)
                    except StopIteration:
                        trainloader_gt_iter = enumerate(trainloader_gt)
                        _, batch = next(trainloader_gt_iter)

                    _, labels_gt, _, _ = batch
                    D_gt_v = one_hot(labels_gt)
                    ignore_mask_gt = (labels_gt.numpy() == ignore_label)

                    D_out = model_D(D_gt_v)
                    loss_D = bce_loss(
                        D_out, make_D_label(gt_label, ignore_mask_gt))
                    loss_D = loss_D / iter_size / 2
                    loss_D.backward()
                    loss_D_value += float(loss_D)

            optimizer.step()
            optimizer_D.step()

            sys.stdout.write('.')
            sys.stdout.flush()

            if i_iter % eval_every == 0 and i_iter != 0:
                model.eval()
                with torch.no_grad():
                    evaluator = EvaluatorIoU(ds.num_classes)
                    for index, batch in enumerate(testloader):
                        image, label, size, name = batch
                        size = size[0].numpy()
                        image = image.float().to(torch_device)
                        output = model(image)
                        output = output.cpu().data[0].numpy()
                        output = output[:, :size[0], :size[1]]
                        # np.int alias was removed from modern NumPy.
                        gt = np.asarray(
                            label[0].numpy()[:size[0], :size[1]], dtype=int)
                        output = output.transpose(1, 2, 0)
                        output = np.asarray(np.argmax(output, axis=2),
                                            dtype=int)
                        evaluator.sample(gt, output,
                                         ignore_value=ignore_label)
                        sys.stdout.write('+')
                        sys.stdout.flush()

                per_class_iou = evaluator.score()
                mean_iou = per_class_iou.mean()

                # Report running means over the evaluation window.
                loss_seg_value /= eval_every
                loss_adv_pred_value /= eval_every
                loss_D_value /= eval_every
                loss_semi_mask_accum /= eval_every
                loss_semi_value /= eval_every
                loss_semi_adv_value /= eval_every

                sys.stdout.write('\n')
                t2 = time.time()
                print(
                    'iter = {:8d}/{:8d}, took {:.3f}s, loss_seg = {:.6f}, loss_adv_p = {:.6f}, loss_D = {:.6f}, loss_semi_mask_rate = {:.3%} loss_semi = {:.6f}, loss_semi_adv = {:.3f}'
                    .format(i_iter, num_steps, t2 - t1, loss_seg_value,
                            loss_adv_pred_value, loss_D_value,
                            loss_semi_mask_accum, loss_semi_value,
                            loss_semi_adv_value))
                for i, (class_name, iou) in enumerate(
                        zip(ds.class_names, per_class_iou)):
                    print('class {:2d} {:12} IU {:.2f}'.format(
                        i, class_name, iou))
                print('meanIOU: ' + str(mean_iou) + '\n')

                loss_seg_value = 0
                loss_adv_pred_value = 0
                loss_D_value = 0
                loss_semi_value = 0
                loss_semi_mask_accum = 0
                loss_semi_adv_value = 0
                t1 = t2

            if snapshot_dir is not None and \
                    i_iter % save_snapshot_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(
                    model.state_dict(),
                    osp.join(snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
                torch.save(
                    model_D.state_dict(),
                    osp.join(snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

        if snapshot_dir is not None:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(snapshot_dir, 'VOC_' + str(num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(snapshot_dir, 'VOC_' + str(num_steps) + '_D.pth'))
def main():
    """Create the model and start CDAN-based adversarial training.

    Trains DeepLabMulti on GTA5 (source) against Cityscapes (target) with a
    single conditional discriminator (CDAN, concatenated multi-level
    softmax outputs).  Supports restarting from a previous snapshot plus its
    saved argument dict.
    """
    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)

    args_dict = vars(args)
    import json

    ###### load args for restart ######
    if RESTART:
        args_dict_file = args.snapshot_dir + '/args_dict_{}.json'.format(
            RESTART_ITER)
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        for arg in args_dict:
            args_dict[arg] = args_dict_last[arg]
    ###### load args for restart ######

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)

    # CDAN concatenates both classifier outputs, hence 2 * num_classes.
    model_D = FCDiscriminator(num_classes=2 * args.num_classes).to(device)

    #### restore model_D and model
    if RESTART:
        # model parameters
        restart_from_model = args.restart_from + 'GTA5_{}.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_model)
        model.load_state_dict(saved_state_dict)

        # model_D parameters
        restart_from_D = args.restart_from + 'GTA5_{}_D.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_D)
        model_D.load_state_dict(saved_state_dict)

    #### model_D is randomly initialized, model is pre-trained ResNet on
    #### ImageNet
    else:
        # model parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Copy backbone weights; skip layer5 when the class count differs
        # from the 19-class checkpoint.
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.to(device)
    model_D.train()
    model_D.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(
        args.data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr
    # setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # NOTE: bce_loss selected here is currently unused below (the loop
    # constructs nn.BCEWithLogitsLoss() inline) — kept for interface parity.
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):
        loss_seg_value1 = 0
        loss_seg_value2 = 0
        adv_loss_value = 0
        d_loss_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        # NOTE(review): the generator LR schedule is applied to optimizer_D
        # here (adjust_learning_rate, not adjust_learning_rate_D) — confirm
        # this is intended.
        adjust_learning_rate(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):
            # train G
            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # train with source
            _, batch = trainloader_iter.__next__()
            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            # images.size() == [1, 3, 720, 1280]
            pred1, pred2 = model(images)
            # pred1, pred2 size == [1, 19, 91, 161]
            pred1 = interp(pred1)
            pred2 = interp(pred2)
            # size (1, 19, 720, 1280)

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target
            _, batch = targetloader_iter.__next__()
            for params in model_D.parameters():
                params.requires_grad_(requires_grad=False)

            images, _, _ = batch
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)
            # pred_target1, 2 == [1, 19, 720, 1280]

            # Fool the conditional discriminator: target predictions should
            # be classified as source (label 0).
            D_out_target = CDAN(
                [F.softmax(pred_target1),
                 F.softmax(pred_target2)],
                model_D,
                cdan_implement='concat')
            dc_source = torch.FloatTensor(
                D_out_target.size()).fill_(0).to(device)
            adv_loss = nn.BCEWithLogitsLoss()(D_out_target, dc_source)
            adv_loss = adv_loss / args.iter_size
            adv_loss = args.lambda_adv * adv_loss
            adv_loss.backward()
            adv_loss_value += adv_loss.item()

            # train D: bring back requires_grad and use detached predictions.
            for params in model_D.parameters():
                params.requires_grad_(requires_grad=True)

            pred1 = pred1.detach()
            pred2 = pred2.detach()
            D_out = CDAN([F.softmax(pred1), F.softmax(pred2)],
                         model_D,
                         cdan_implement='concat')
            dc_source = torch.FloatTensor(D_out.size()).fill_(0).to(device)
            d_loss = nn.BCEWithLogitsLoss()(D_out, dc_source)
            d_loss = d_loss / args.iter_size
            d_loss.backward()
            d_loss_value += d_loss.item()

            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()
            D_out_target = CDAN(
                [F.softmax(pred_target1),
                 F.softmax(pred_target2)],
                model_D,
                cdan_implement='concat')
            dc_target = torch.FloatTensor(
                D_out_target.size()).fill_(1).to(device)
            d_loss = nn.BCEWithLogitsLoss()(D_out_target, dc_target)
            d_loss = d_loss / args.iter_size
            d_loss.backward()
            d_loss_value += d_loss.item()
            # (a leftover no-op `continue` at the end of this loop body was
            # removed)

        optimizer.step()
        optimizer_D.step()

        scalar_info = {
            'loss_seg1': loss_seg_value1,
            'loss_seg2': loss_seg_value2,
            'generator_loss': adv_loss_value,
            'discriminator_loss': d_loss_value,
        }
        if i_iter % 10 == 0:
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} generator = {4:.3f}, discriminator = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1,
                    loss_seg_value2, adv_loss_value, d_loss_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(i_iter) + '_D.pth'))
            # check_original_discriminator(args, pred_target1, pred_target2,
            #                              i_iter)
            save_path = args.snapshot_dir + '/eval_{}'.format(i_iter)
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            # evaluate(args, save_path, args.snapshot_dir, i_iter)

            ###### also record latest saved iteration #######
            args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
            args_dict['learning_rate_D'] = optimizer_D.param_groups[0]['lr']
            args_dict['start_steps'] = i_iter
            # BUGFIX: include the '/' separator so the restart loader at the
            # top of this function (snapshot_dir + '/args_dict_*.json') can
            # actually find this file.
            args_dict_file = args.snapshot_dir + '/args_dict_{}.json'.format(
                i_iter)
            # BUGFIX: removed a live pdb.set_trace() that halted every
            # unattended run at the first snapshot.
            with open(args_dict_file, 'w') as f:
                json.dump(args_dict, f)
            ###### also record latest saved iteration #######

    writer.close()
def main():
    """Create the model and start dual-discriminator adversarial training.

    Trains DeepLabMulti on GTA5 (source) vs. Cityscapes (target) with one
    FCDiscriminator per classifier head (AdaptSegNet-style multi-level
    output-space adaptation), evaluating mIoU on Cityscapes val every 1000
    iterations.
    """
    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        #model.load_state_dict(saved_state_dict)

        # Copy backbone weights; skip layer5 when the class count differs
        # from the 19-class checkpoint.
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)
    # model_D1.load_state_dict(torch.load('./snapshots/local_00002/GTA5_21000_D1.pth'))
    # model_D2.load_state_dict(torch.load('./snapshots/local_00002/GTA5_21000_D2.pth'))

    model_D1.train()
    model_D1.to(device)
    model_D2.train()
    model_D2.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(
        args.data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # Load VGG
    #vgg19 = torchvision.models.vgg19(pretrained=True)
    #vgg19.to(device)

    # implement model.optim_parameters(args) to handle different models' lr
    # setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()
    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        writer = SummaryWriter(args.log_dir)

    for i_iter in range(0, args.num_steps):
        loss_seg_value = 0
        # loss_seg_local_value = 0
        loss_adv_target_value = 0
        # loss_adv_local_value = 0
        loss_D_value = 0
        # loss_D_local_value = 0
        loss_local_match_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D1.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):
            # train G
            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False
            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            _, batch = trainloader_iter.__next__()
            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            pred_s1, pred_s2, _, _ = model(images)
            #f_s2 = normalize(f_s2)
            pred_s1, pred_s2 = interp(pred_s1), interp(pred_s2)
            loss_seg = args.lambda_seg * seg_loss(pred_s1,
                                                  labels) + seg_loss(
                                                      pred_s2, labels)
            del labels

            # proper normalization
            loss_seg_value += loss_seg.item() / args.iter_size

            # train with target
            _, batch = targetloader_iter.__next__()
            images, _, _ = batch
            images = images.to(device)

            pred_t1, pred_t2, _, _ = model(images)
            #f_t2 = normalize(f_t2)
            pred_t1, pred_t2 = interp_target(pred_t1), interp_target(
                pred_t2)
            del images

            # Fool both discriminators: target outputs should look source.
            D_out_1 = model_D1(F.softmax(pred_t1, dim=1))
            D_out_2 = model_D2(F.softmax(pred_t2, dim=1))
            loss_adv_target1 = bce_loss(
                D_out_1,
                torch.FloatTensor(
                    D_out_1.data.size()).fill_(source_label).to(device))
            loss_adv_target2 = bce_loss(
                D_out_2,
                torch.FloatTensor(
                    D_out_2.data.size()).fill_(source_label).to(device))
            loss_adv_target_value += (
                args.lambda_adv_target1 * loss_adv_target1 +
                args.lambda_adv_target *
                loss_adv_target2).item() / args.iter_size

            # Segmentation + adversarial losses are backpropagated together.
            loss = loss_seg + args.lambda_adv_target1 * loss_adv_target1 + \
                args.lambda_adv_target * loss_adv_target2
            del D_out_1, D_out_2

            # #< Local patch part># #
            # corres_id2 = get_correspondance(f_s2, f_t2, pred_s2, pred_t2)
            # #loss_local1 = local_feature_loss(corres_id1, f_s1, f_t1, model, seg_loss)
            # loss_local2 = local_feature_loss(corres_id2, labels, f_t2, model, seg_loss)
            # loss_local = args.lambda_match_target2 * loss_local2 #+args.lambda_match_target1 * loss_local1
            # loss += loss_local
            # if corres_id2.nelement() > 0:
            #     loss_local_match_value += loss_local.item()/ args.iter_size

            loss /= args.iter_size
            loss.backward()

            # train D
            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred_s1, pred_s2 = pred_s1.detach(), pred_s2.detach()
            # CONSISTENCY FIX: dim=1 made explicit (matches the calls above;
            # implicit-dim F.softmax is deprecated and resolved to dim=1 for
            # 4D inputs anyway).
            D_out_1, D_out_2 = model_D1(F.softmax(pred_s1,
                                                  dim=1)), model_D2(
                                                      F.softmax(pred_s2,
                                                                dim=1))
            loss_D_1 = bce_loss(
                D_out_1,
                torch.FloatTensor(D_out_1.data.size()).fill_(
                    source_label).to(device)) / args.iter_size / 2
            loss_D_2 = bce_loss(
                D_out_2,
                torch.FloatTensor(D_out_2.data.size()).fill_(
                    source_label).to(device)) / args.iter_size / 2
            loss_D_1.backward()
            loss_D_2.backward()
            loss_D_value += (loss_D_1 + loss_D_2).item()

            # train with target
            pred_t1, pred_t2 = pred_t1.detach(), pred_t2.detach()
            D_out_1, D_out_2 = model_D1(F.softmax(pred_t1,
                                                  dim=1)), model_D2(
                                                      F.softmax(pred_t2,
                                                                dim=1))
            loss_D_1 = bce_loss(
                D_out_1,
                torch.FloatTensor(D_out_1.data.size()).fill_(
                    target_label).to(device)) / args.iter_size / 2
            loss_D_2 = bce_loss(
                D_out_2,
                torch.FloatTensor(D_out_2.data.size()).fill_(
                    target_label).to(device)) / args.iter_size / 2
            loss_D_1.backward()
            loss_D_2.backward()
            loss_D_value += (loss_D_1 + loss_D_2).item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if i_iter % 1000 == 0:
            # Periodic mIoU evaluation on Cityscapes val.  NOTE(review):
            # between evaluations the logged/printed mIoU is the stale value
            # from the last eval (bound at i_iter == 0).
            val_dir = '../dataset/Cityscapes/leftImg8bit_trainvaltest/'
            val_list = './dataset/cityscapes_list/val.txt'
            save_dir = './results/tmp'
            gt_dir = '../dataset/Cityscapes/gtFine_trainvaltest/gtFine/val'
            evaluate_cityscapes.test_model(model, device, val_dir, val_list,
                                           save_dir)
            mIoU = compute_iou.mIoUforTest(gt_dir, save_dir)

        if args.tensorboard:
            scalar_info = {
                'loss_seg': loss_seg_value,
                #'loss_seg_local': loss_seg_local_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_local_match': loss_local_match_value,
                'loss_D': loss_D_value,
                'mIoU': mIoU
                #'loss_D_local': loss_D_local_value
            }
            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f}, loss_D = {4:.3f}, loss_local_match = {5:.3f}, mIoU = {6:3f} '
            #'loss_seg_local = {5:.3f} loss_adv_local = {6:.3f}, loss_D_local = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_target_value, loss_D_value,
                    loss_local_match_value, mIoU)
            #loss_seg_local_value, loss_adv_local_value, loss_D_local_value)
        )

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
def main():
    """Create the model and start the training.

    AdaptSegNet-style adversarial domain adaptation for semantic segmentation
    from OpenRooms (source) to NYU (target): a two-head DeepLab generator and
    one FCDiscriminator per head.  ``args.mode`` selects the variant —
    'baseline' / 'baseline_tar' skip the adversarial part entirely, while
    'sda' additionally supervises the labelled subset of each target batch.

    All configuration is read from the module-level ``args``; snapshots go to
    ``args.snapshot_dir`` and optional TensorBoard logs to ``args.log_dir``.
    """
    # OpenRooms label id -> NYU40 label id.  The map is applied elementwise
    # (np.vectorize) and shifted by -1 to make labels 0-based.
    # NOTE(review): the -1 shift also hits the 255 "ignore" entry (-> 254);
    # presumably remapped/ignored downstream — confirm in the dataset class.
    or_nyu_dict = {
        0: 255, 1: 16, 2: 40, 3: 39, 4: 7, 5: 14, 6: 39, 7: 12, 8: 38,
        9: 40, 10: 10, 11: 6, 12: 40, 13: 39, 14: 39, 15: 40, 16: 18,
        17: 40, 18: 4, 19: 40, 20: 40, 21: 5, 22: 40, 23: 40, 24: 30,
        25: 36, 26: 38, 27: 40, 28: 3, 29: 40, 30: 40, 31: 9, 32: 38,
        33: 40, 34: 40, 35: 40, 36: 34, 37: 37, 38: 40, 39: 40, 40: 39,
        41: 8, 42: 3, 43: 1, 44: 2, 45: 22
    }
    or_nyu_map = lambda x: or_nyu_dict.get(x, x) - 1
    or_nyu_map = np.vectorize(or_nyu_map)

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    args.or_nyu_map = or_nyu_map

    # Create network and (optionally) restore pretrained weights.
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        elif args.restore_from == "":
            saved_state_dict = None  # train from scratch
        else:
            saved_state_dict = torch.load(args.restore_from)
        # FIX: originally `for i in saved_state_dict` ran unconditionally and
        # raised TypeError when restore_from == "" (saved_state_dict is None).
        if saved_state_dict is not None:
            new_params = model.state_dict().copy()
            for key in saved_state_dict:
                # Keys look like 'Scale.layer5.conv2d_list.3.weight': strip
                # the leading scope; skip the 'layer5' classifier head when
                # the class count is 40 so it is left freshly initialized.
                key_parts = key.split('.')
                if not args.num_classes == 40 or not key_parts[1] == 'layer5':
                    new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
            model.load_state_dict(new_params)

    model.train()
    model.to(device)
    cudnn.benchmark = True

    if args.mode != "baseline" and args.mode != "baseline_tar":
        # init D: one discriminator per segmentation head.
        model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
        model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)
        model_D1.train()
        model_D1.to(device)
        model_D2.train()
        model_D2.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Augmentation parameters.
    scale_min = 0.5
    scale_max = 2.0
    rotate_min = -10
    rotate_max = 10
    ignore_label = 255
    # (Removed dead locals: ImageNet mean/std were computed here but never
    # used — Normalize below uses IMG_MEAN with std=[1, 1, 1].)

    # NOTE(review): w/h here come from the *target* input size (last split
    # above) — assumes source and target crops share a size; confirm.
    args.width = w
    args.height = h

    train_transform = transforms.Compose([
        transforms.RandScale([scale_min, scale_max]),
        transforms.RandRotate([rotate_min, rotate_max],
                              padding=IMG_MEAN_RGB,
                              ignore_label=ignore_label),
        transforms.RandomGaussianBlur(),
        transforms.RandomHorizontalFlip(),
        transforms.Crop([args.height + 1, args.width + 1],
                        crop_type='rand',
                        padding=IMG_MEAN_RGB,
                        ignore_label=ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMG_MEAN, std=[1, 1, 1]),
    ])
    val_transform = transforms.Compose([
        transforms.Crop([args.height + 1, args.width + 1],
                        crop_type='center',
                        padding=IMG_MEAN_RGB,
                        ignore_label=ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMG_MEAN, std=[1, 1, 1]),
    ])

    # Source dataset: OpenRooms, except in 'baseline_tar' mode where the
    # labelled part of NYU plays the source role.
    if args.mode != "baseline_tar":
        src_train_dst = OpenRoomsSegmentation(root=args.data_dir,
                                              opt=args,
                                              split='train',
                                              transform=train_transform,
                                              imWidth=args.width,
                                              imHeight=args.height,
                                              remap_labels=args.or_nyu_map)
    else:
        src_train_dst = NYU_Labelled(root=args.data_dir_target,
                                     opt=args,
                                     split='train',
                                     transform=train_transform,
                                     imWidth=args.width,
                                     imHeight=args.height,
                                     phase="TRAIN",
                                     randomize=True)
    tar_train_dst = NYU(root=args.data_dir_target,
                        opt=args,
                        split='train',
                        transform=train_transform,
                        imWidth=args.width,
                        imHeight=args.height,
                        phase="TRAIN",
                        randomize=True,
                        mode=args.mode)
    # NOTE(review): tar_val_dst is built (from args.data_dir, not
    # args.data_dir_target) but never consumed below — dead code or a missing
    # validation loop?  Kept to preserve any constructor side effects.
    tar_val_dst = NYU(root=args.data_dir,
                      opt=args,
                      split='val',
                      transform=val_transform,
                      imWidth=args.width,
                      imHeight=args.height,
                      phase="TRAIN",
                      randomize=False)

    trainloader = data.DataLoader(src_train_dst,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)
    targetloader = data.DataLoader(tar_train_dst,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    if args.mode != "baseline" and args.mode != "baseline_tar":
        optimizer_D1 = optim.Adam(model_D1.parameters(),
                                  lr=args.learning_rate_D,
                                  betas=(0.9, 0.99))
        optimizer_D1.zero_grad()
        optimizer_D2 = optim.Adam(model_D2.parameters(),
                                  lr=args.learning_rate_D,
                                  betas=(0.9, 0.99))
        optimizer_D2.zero_grad()

    # GAN objective: vanilla (logistic) or least-squares.
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1] + 1, input_size[0] + 1),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1] + 1,
                                      input_size_target[0] + 1),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):
        loss_seg_value1 = 0
        loss_seg_value1_tar = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        loss_seg_value2 = 0
        loss_seg_value2_tar = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        if args.mode != "baseline" and args.mode != "baseline_tar":
            optimizer_D1.zero_grad()
            optimizer_D2.zero_grad()
            adjust_learning_rate_D(optimizer_D1, i_iter)
            adjust_learning_rate_D(optimizer_D2, i_iter)

        # Samples kept around for the periodic TensorBoard image dump below.
        sample_src = None
        sample_tar = None
        sample_gt_src = None
        sample_gt_tar = None

        for sub_i in range(args.iter_size):
            # ---- train G ----
            if args.mode != "baseline" and args.mode != "baseline_tar":
                # don't accumulate grads in D
                for param in model_D1.parameters():
                    param.requires_grad = False
                for param in model_D2.parameters():
                    param.requires_grad = False

            # train with source
            # FIX: bare `except:` replaced with `except StopIteration:` so
            # real errors (CUDA OOM, KeyboardInterrupt, ...) are not swallowed
            # by the loader-restart path.
            try:
                _, batch = trainloader_iter.__next__()
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = trainloader_iter.__next__()
            images, labels, _ = batch
            sample_src = images.clone()
            sample_gt_src = labels.clone()
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)
            sample_pred_src = pred2.detach().cpu()

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target
            try:
                _, batch = targetloader_iter.__next__()
            except StopIteration:
                targetloader_iter = enumerate(targetloader)
                _, batch = targetloader_iter.__next__()
            images, tar_labels, _, labelled = batch
            n_labelled = labelled.sum().detach().item()
            sample_tar = images.clone()
            sample_gt_tar = tar_labels.clone()
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            if args.mode == "sda" and n_labelled != 0:
                # Semi-supervised DA: supervise the labelled target images.
                labelled = labelled.to(device) == 1
                tar_labels = tar_labels.to(device)
                loss_seg1_tar = seg_loss(pred_target1[labelled],
                                         tar_labels[labelled])
                loss_seg2_tar = seg_loss(pred_target2[labelled],
                                         tar_labels[labelled])
                loss_tar_labelled = loss_seg2_tar + args.lambda_seg * loss_seg1_tar
                loss_tar_labelled = loss_tar_labelled / args.iter_size
                loss_seg_value1_tar += loss_seg1_tar.item() / args.iter_size
                loss_seg_value2_tar += loss_seg2_tar.item() / args.iter_size
            else:
                loss_tar_labelled = torch.zeros(
                    1, requires_grad=True).float().to(device)

            sample_pred_tar = pred_target2.detach().cpu()

            if args.mode != "baseline" and args.mode != "baseline_tar":
                # Fool the discriminators: target predictions labelled as source.
                # FIX: explicit dim=1 (class channel) — F.softmax without dim
                # is deprecated; dim=1 is what the implicit rule chose for
                # these 4-D tensors, so behavior is unchanged.
                D_out1 = model_D1(F.softmax(pred_target1, dim=1))
                D_out2 = model_D2(F.softmax(pred_target2, dim=1))
                loss_adv_target1 = bce_loss(
                    D_out1,
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label).to(device))
                loss_adv_target2 = bce_loss(
                    D_out2,
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label).to(device))
                loss = args.lambda_adv_target1 * loss_adv_target1 \
                    + args.lambda_adv_target2 * loss_adv_target2
                # loss_tar_labelled was already normalized by iter_size above.
                loss = loss / args.iter_size + loss_tar_labelled
                loss.backward()
                loss_adv_target_value1 += loss_adv_target1.item(
                ) / args.iter_size
                loss_adv_target_value2 += loss_adv_target2.item(
                ) / args.iter_size

                # ---- train D ----
                # bring back requires_grad
                for param in model_D1.parameters():
                    param.requires_grad = True
                for param in model_D2.parameters():
                    param.requires_grad = True

                # train with source
                pred1 = pred1.detach()
                pred2 = pred2.detach()
                D_out1 = model_D1(F.softmax(pred1, dim=1))
                D_out2 = model_D2(F.softmax(pred2, dim=1))
                loss_D1 = bce_loss(
                    D_out1,
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label).to(device))
                loss_D2 = bce_loss(
                    D_out2,
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label).to(device))
                loss_D1 = loss_D1 / args.iter_size / 2
                loss_D2 = loss_D2 / args.iter_size / 2
                loss_D1.backward()
                loss_D2.backward()
                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

                # train with target
                pred_target1 = pred_target1.detach()
                pred_target2 = pred_target2.detach()
                D_out1 = model_D1(F.softmax(pred_target1, dim=1))
                D_out2 = model_D2(F.softmax(pred_target2, dim=1))
                loss_D1 = bce_loss(
                    D_out1,
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(target_label).to(device))
                loss_D2 = bce_loss(
                    D_out2,
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label).to(device))
                loss_D1 = loss_D1 / args.iter_size / 2
                loss_D2 = loss_D2 / args.iter_size / 2
                loss_D1.backward()
                loss_D2.backward()
                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

        optimizer.step()
        if args.mode != "baseline" and args.mode != "baseline_tar":
            optimizer_D1.step()
            optimizer_D2.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg1': loss_seg_value1,
                'loss_seg2': loss_seg_value2,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D1': loss_D_value1,
                'loss_D2': loss_D_value2,
                'loss_seg1_tar': loss_seg_value1_tar,
                'loss_seg2_tar': loss_seg_value2_tar,
            }
            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)
            if i_iter % 1000 == 0:
                # Un-normalize (add back the per-channel mean) and reorder
                # BGR -> RGB for display.
                img = sample_src.cpu()[:, [2, 1, 0], :, :] + torch.from_numpy(
                    np.array(IMG_MEAN_RGB).reshape(1, 3, 1, 1)).float()
                img = img.type(torch.uint8)
                writer.add_images("Src/Images", img, i_iter)
                label = tar_train_dst.decode_target(sample_gt_src).transpose(
                    0, 3, 1, 2)
                writer.add_images("Src/Labels", label, i_iter)
                preds = sample_pred_src.permute(0, 2, 3, 1).cpu().numpy()
                preds = np.asarray(np.argmax(preds, axis=3), dtype=np.uint8)
                preds = tar_train_dst.decode_target(preds).transpose(
                    0, 3, 1, 2)
                writer.add_images("Src/Preds", preds, i_iter)

                tar_img = sample_tar.cpu()[:, [2, 1, 0], :, :] + \
                    torch.from_numpy(
                        np.array(IMG_MEAN_RGB).reshape(1, 3, 1, 1)).float()
                tar_img = tar_img.type(torch.uint8)
                writer.add_images("Tar/Images", tar_img, i_iter)
                tar_label = tar_train_dst.decode_target(
                    sample_gt_tar).transpose(0, 3, 1, 2)
                writer.add_images("Tar/Labels", tar_label, i_iter)
                tar_preds = sample_pred_tar.permute(0, 2, 3, 1).cpu().numpy()
                tar_preds = np.asarray(np.argmax(tar_preds, axis=3),
                                       dtype=np.uint8)
                tar_preds = tar_train_dst.decode_target(tar_preds).transpose(
                    0, 3, 1, 2)
                writer.add_images("Tar/Preds", tar_preds, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f} loss_seg1_tar={8:.3f} loss_seg2_tar={9:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2, loss_seg_value1_tar,
                    loss_seg_value2_tar))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'OR_' + str(args.num_steps_stop) + '.pth'))
            if args.mode != "baseline" and args.mode != "baseline_tar":
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(args.num_steps_stop) + '_D1.pth'))
                torch.save(
                    model_D2.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'OR_' + str(i_iter) + '.pth'))
            if args.mode != "baseline" and args.mode != "baseline_tar":
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(i_iter) + '_D1.pth'))
                torch.save(
                    model_D2.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
net.load_state_dict(torch.load(args.load, map_location=device)) logging.info(f'Model loaded from {args.load}') net.to(device=device) # faster convolutions, but more memory cudnn.benchmark = True from torchsummary import summary summary(net, (config.NUM_CHANNELS, config.CROP_H, config.CROP_W)) ####################### # Discriminator ####################### discriminator1 = FCDiscriminator(num_classes=config.NUM_CLASSES) discriminator1.to(device=device) discriminator2 = FCDiscriminator(num_classes=config.NUM_CLASSES) discriminator2.to(device=device) logging.info(f'Discriminator:\n' f'\t{config.NUM_CLASSES} input channels (classes)\n') from torchsummary import summary summary(discriminator1, (config.NUM_CLASSES, config.CROP_H, config.CROP_W)) summary(discriminator2, (config.NUM_CLASSES, config.CROP_H, config.CROP_W)) ####################### ####################### try:
def main():
    """Create the model and start the training.

    GTA5 -> Cityscapes domain adaptation: a DeepLab segmentation network is
    trained with (a) source cross-entropy, (b) a feature-level adversarial
    loss against an FCDiscriminator, (c) self-training on target
    pseudo-labels, and (d) class-center alignment after iteration 10000.
    Periodically evaluates mIoU on the Cityscapes val split and keeps the
    best snapshot.  All configuration comes from the module-level ``args``.
    """
    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    bestIoU = 0   # best validation mIoU seen so far
    bestIter = 0  # iteration that produced bestIoU

    # Create network.
    # NOTE(review): if args.model is neither 'ResNet' nor 'VGG', `model` is
    # never bound and model.train() below raises UnboundLocalError.
    if args.model == 'ResNet':
        model = DeeplabMulti(num_classes=args.num_classes)
        saved_state_dict = torch.load(args.restore_from)
        model.load_state_dict(saved_state_dict)
    if args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        saved_state_dict = torch.load(args.restore_from)
        model.load_state_dict(saved_state_dict)

    model.train()
    model.to(device)
    cudnn.benchmark = True

    # init D — it consumes 256-channel feature maps, hence num_classes=256.
    if args.model == 'ResNet':
        model_D = FCDiscriminator(num_classes=256).to(device)
        # NOTE(review): hard-coded warm-start checkpoint path; confirm the
        # file exists before launching the ResNet configuration.
        saved_state_dict = torch.load('./snapshots/BestGTA5_D.pth')
        model_D.load_state_dict(saved_state_dict)
    if args.model == 'VGG':
        model_D = FCDiscriminator(num_classes=256).to(device)
    model_D.train()
    model_D.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Source loader: GTA5 with ground-truth labels.
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size *
                                              args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    # Target loader: Cityscapes with pseudo-labels for self-training.
    targetloader = data.DataLoader(cityscapesDataSetLabel(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)
    # Full-resolution upsampling for Cityscapes validation images.
    test_interp = nn.Upsample(size=(1024, 2048),
                              mode='bilinear',
                              align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # Load the precomputed class centers used to initialize the running
    # (EMA) centers for the alignment loss.
    class_center_source_ori = np.load('./source_center.npy')
    class_center_source_ori = torch.from_numpy(class_center_source_ori)
    class_center_target_ori = np.load('./target_center.npy')
    class_center_target_ori = torch.from_numpy(class_center_target_ori)

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):
        loss_seg = 0
        loss_adv_target_value = 0
        loss_D_value = 0
        loss_cla_value = 0
        loss_square_value = 0
        loss_st_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train G
        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source
        _, batch = trainloader_iter.__next__()
        images, labels, _, _ = batch
        images = images.to(device)
        labels_s = labels  # copy for center calculation
        labels = labels.long().to(device)
        feature, prediction = model(images)
        feature_s = feature  # copy for center calculation
        prediction = interp(prediction)
        loss = seg_loss(prediction, labels)
        # retain_graph=True: the source graph (feature_s) is reused by the
        # class-center alignment backward further down.
        loss.backward(retain_graph=True)
        loss_seg = loss.item()

        # train with target
        _, batch = targetloader_iter.__next__()
        images, labels_pseudo, _, _ = batch
        labels_t = labels_pseudo  # copy for center calculation
        images = images.to(device)
        labels_pseudo = labels_pseudo.long().to(device)
        feature_target, pred_target = model(images)
        feature_t = feature_target  # copy for center calculation

        # Adversarial loss: push target features to look like source to D.
        _, D_out = model_D(feature_target)
        loss_adv_target = bce_loss(
            D_out,
            torch.FloatTensor(
                D_out.data.size()).fill_(source_label).to(device))
        loss = args.lambda_adv_target * loss_adv_target
        loss.backward(retain_graph=True)
        loss_adv_target_value = loss_adv_target.item()

        # Self-training loss on target pseudo-labels.
        pred_target = interp_target(pred_target)
        loss_st = seg_loss(pred_target, labels_pseudo)
        loss_st.backward(retain_graph=True)
        loss_st_value = loss_st.item()

        # class center alignment begin
        if i_iter > 10000:
            class_center_source = class_center_cal(feature_s, labels_s)
            class_center_target = class_center_cal(feature_t, labels_t)
            # Moving-average update of the running per-class centers.
            class_center_source_ori = class_center_update(
                class_center_source, class_center_source_ori,
                args.lambda_center_update)
            class_center_target_ori = class_center_update(
                class_center_target, class_center_target_ori,
                args.lambda_center_update)
            # Detach the source centers so gradients only flow through the
            # target side of the difference below.
            class_center_source_ori = class_center_source_ori.detach(
            )
            # align target center to source
            center_diff = class_center_source_ori - class_center_target_ori
            loss_square = torch.pow(center_diff, 2).sum()
            loss = args.lambda_center * loss_square
            loss.backward()
            loss_square_value = loss_square.item()
        # class center alignment end

        # train D
        # bring back requires_grad
        for param in model_D.parameters():
            param.requires_grad = True

        # train with source.  D also has an auxiliary classification head
        # (`cla`) that is supervised with the source segmentation labels.
        feature = feature.detach()
        cla, D_out = model_D(feature)
        cla = interp(cla)
        loss_cla = seg_loss(cla, labels)
        loss_D = bce_loss(
            D_out,
            torch.FloatTensor(
                D_out.data.size()).fill_(source_label).to(device))
        loss_D = loss_D / 2
        loss_Disc = args.lambda_s * loss_cla + loss_D
        loss_Disc.backward()
        loss_cla_value = loss_cla.item()
        loss_D_value = loss_D.item()

        # train with target
        feature_target = feature_target.detach()
        _, D_out = model_D(feature_target)
        loss_D = bce_loss(
            D_out,
            torch.FloatTensor(
                D_out.data.size()).fill_(target_label).to(device))
        loss_D = loss_D / 2
        loss_D.backward()
        loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()
        # Cut the target centers out of this iteration's graph before reuse.
        class_center_target_ori = class_center_target_ori.detach()

        if args.tensorboard:
            scalar_info = {
                'loss_seg': loss_seg,
                'loss_cla': loss_cla_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_st_value': loss_st_value,
                'loss_D': loss_D_value,
            }
            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f} loss_D = {4:.3f} loss_cla = {5:.3f} loss_st = {6:.5f} loss_square = {7:.5f}'
            .format(i_iter, args.num_steps, loss_seg, loss_adv_target_value,
                    loss_D_value, loss_cla_value, loss_st_value,
                    loss_square_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            if not os.path.exists(args.save):
                os.makedirs(args.save)
            # Evaluate on the Cityscapes validation split and dump colorized
            # predictions to args.save for the external mIoU computation.
            testloader = data.DataLoader(cityscapesDataSet(
                args.data_dir_target,
                args.data_list_target_test,
                crop_size=(1024, 512),
                mean=IMG_MEAN,
                scale=False,
                mirror=False,
                set='val'),
                                         batch_size=1,
                                         shuffle=False,
                                         pin_memory=True)
            model.eval()
            for index, batch in enumerate(testloader):
                if index % 100 == 0:
                    print('%d processd' % index)
                image, _, name = batch
                with torch.no_grad():
                    output1, output2 = model(Variable(image).to(device))
                output = test_interp(output2).cpu().data[0].numpy()
                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2),
                                    dtype=np.uint8)
                output = Image.fromarray(output)
                name = name[0].split('/')[-1]
                output.save('%s/%s' % (args.save, name))
            mIoUs = compute_mIoU(osp.join(args.data_dir_target, 'gtFine/val'),
                                 args.save, 'dataset/cityscapes_list')
            mIoU = round(np.nanmean(mIoUs) * 100, 2)
            print('===> current mIoU: ' + str(mIoU))
            print('===> last best mIoU: ' + str(bestIoU))
            print('===> last best iter: ' + str(bestIter))
            if mIoU > bestIoU:
                bestIoU = mIoU
                bestIter = i_iter
                torch.save(model.state_dict(),
                           osp.join(args.snapshot_dir, 'BestGTA5.pth'))
                torch.save(model_D.state_dict(),
                           osp.join(args.snapshot_dir, 'BestGTA5_D.pth'))
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(i_iter) + '_D.pth'))
            model.train()

    if args.tensorboard:
        writer.close()