def main():
    """Create the model and start the training.

    Single-output adversarial domain adaptation.  The segmentation network
    is trained with a supervised loss (``loss_calc``) on source batches
    (``GTA5DataSet`` or ``SynthiaDataSet``, chosen by the module-level
    ``SOURCE_DATA``) and with an adversarial loss on unlabelled target
    batches from ``cityscapesDataSet``; ``model_D2`` is trained to tell the
    two prediction distributions apart.  ``model_D1`` is built, stepped and
    checkpointed but never used in a forward pass here, so every "*1" loss
    in the log stays 0.

    Every ``args.save_pred_every`` iterations the model is evaluated with
    ``proceed_test``; a checkpoint is written to the logger directory and
    copied to ``model_best.pth.tar`` when the returned mIoU improves.

    All configuration is read from the module-level ``args``.

    Raises:
        ValueError: if ``args.model`` or ``SOURCE_DATA`` is not supported.
    """
    import shutil

    # NOTE(review): the unpacked names say (h, w), but the tuples are only
    # consumed positionally (dataset ``crop_size`` as-is, ``nn.Upsample``
    # reversed), so the order given on the command line is preserved.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)

    cudnn.enabled = True
    from pytorchgo.utils.pytorch_utils import set_gpu
    set_gpu(args.gpu)

    def _label_like(d_out, label):
        """Return a CUDA tensor shaped like ``d_out`` filled with ``label``."""
        return Variable(
            torch.FloatTensor(d_out.data.size()).fill_(label)).cuda()

    def _to_float(loss):
        """Return the value of a one-element loss as a Python float.

        ``ndarray.item()`` works both for the 1-element arrays produced by
        old PyTorch and for the 0-d arrays of PyTorch >= 0.4, where the
        former ``numpy()[0]`` indexing raises ``IndexError``.
        """
        return loss.data.cpu().numpy().item()

    # Create network
    if args.model == 'DeepLab':
        logger.info("adopting Deeplabv2 base model..")
        model = Res_Deeplab(num_classes=args.num_classes, multi_scale=False)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # Keys look like "Scale.layer5.conv2d_list.3.weight"; the leading
            # component is dropped.  layer5 weights are skipped only when
            # num_classes == 19.
            key_parts = key.split('.')
            if args.num_classes != 19 or key_parts[1] != 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)

        # model.optim_parameters(args) handles the per-layer lr setting.
        optimizer = optim.SGD(model.optim_parameters(args),
                              lr=args.learning_rate, momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.model == "FCN8S":
        logger.info("adopting FCN8S base model..")
        from pytorchgo.model.MyFCN8s import MyFCN8s
        # Use args.num_classes (not the NUM_CLASSES constant) so the network's
        # output channels always match the discriminators built below.
        model = MyFCN8s(n_class=args.num_classes)
        vgg16 = torchfcn.models.VGG16(pretrained=True)
        model.copy_params_from_vgg16(vgg16)
        optimizer = optim.SGD(model.parameters(),
                              lr=args.learning_rate, momentum=args.momentum,
                              weight_decay=args.weight_decay)
    else:
        raise ValueError("unsupported model: {}".format(args.model))

    model.train()
    model.cuda()

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)
    model_D1.train()
    model_D1.cuda()
    model_D2.train()
    model_D2.cuda()

    # Dataset length requested from every loader (derived from num_steps).
    max_iters = args.num_steps * args.iter_size * args.batch_size
    if SOURCE_DATA == "GTA5":
        src_dataset = GTA5DataSet(
            args.data_dir, args.data_list, max_iters=max_iters,
            crop_size=input_size, scale=args.random_scale,
            mirror=args.random_mirror, mean=IMG_MEAN)
    elif SOURCE_DATA == "SYNTHIA":
        src_dataset = SynthiaDataSet(
            args.data_dir, args.data_list, LABEL_LIST_PATH,
            max_iters=max_iters, crop_size=input_size,
            scale=args.random_scale, mirror=args.random_mirror,
            mean=IMG_MEAN)
    else:
        raise ValueError("unsupported SOURCE_DATA: {}".format(SOURCE_DATA))
    trainloader = data.DataLoader(
        src_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
        cityscapesDataSet(max_iters=max_iters, crop_size=input_size_target,
                          scale=False, mirror=args.random_mirror,
                          mean=IMG_MEAN, set=args.set),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear')
    interp_target = nn.Upsample(
        size=(input_size_target[1], input_size_target[0]), mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    best_mIoU = 0
    model_summary([model, model_D1, model_D2])
    optimizer_summary([optimizer, optimizer_D1, optimizer_D2])

    for i_iter in tqdm(range(args.num_steps_stop),
                       total=args.num_steps_stop, desc="training"):
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        lr = adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        lr_D1 = adjust_learning_rate_D(optimizer_D1, i_iter)
        # Called for its side effect on optimizer_D2; the value is not logged.
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):
            # ---- train G: don't accumulate grads in D ----
            for param in model_D1.parameters():
                param.requires_grad = False
            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = Variable(images).cuda()

            pred2 = interp(model(images))
            loss_seg2 = loss_calc(pred2, labels)

            # proper normalization over the accumulated sub-iterations
            loss = loss_seg2 / args.iter_size
            loss.backward()
            loss_seg_value2 += _to_float(loss_seg2) / args.iter_size

            # train with target: push D2 to label target predictions "source"
            _, batch = next(targetloader_iter)
            images, _, _, _ = batch
            images = Variable(images).cuda()

            pred_target2 = interp_target(model(images))
            D_out2 = model_D2(F.softmax(pred_target2))
            loss_adv_target2 = bce_loss(D_out2,
                                        _label_like(D_out2, source_label))

            loss = args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value2 += (_to_float(loss_adv_target2)
                                       / args.iter_size)

            # ---- train D: bring back requires_grad ----
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detached so G gets no gradient)
            pred2 = pred2.detach()
            D_out2 = model_D2(F.softmax(pred2))
            loss_D2 = bce_loss(D_out2, _label_like(D_out2, source_label))
            loss_D2 = loss_D2 / args.iter_size / 2
            loss_D2.backward()
            loss_D_value2 += _to_float(loss_D2)

            # train with target
            pred_target2 = pred_target2.detach()
            D_out2 = model_D2(F.softmax(pred_target2))
            loss_D2 = bce_loss(D_out2, _label_like(D_out2, target_label))
            loss_D2 = loss_D2 / args.iter_size / 2
            loss_D2.backward()
            loss_D_value2 += _to_float(loss_D2)

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if i_iter % 100 == 0:
            logger.info(
                'iter = {}/{},loss_seg1 = {:.3f} loss_seg2 = {:.3f} loss_adv1 = {:.3f}, loss_adv2 = {:.3f} loss_D1 = {:.3f} loss_D2 = {:.3f}, lr={:.7f}, lr_D={:.7f}, best miou16= {:.5f}'
                .format(i_iter, args.num_steps_stop, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1,
                        loss_D_value2, lr, lr_D1, best_mIoU))

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            logger.info("saving snapshot.....")
            cur_miou16 = proceed_test(model, input_size)
            is_best = best_mIoU < cur_miou16
            if is_best:
                best_mIoU = cur_miou16
            checkpoint_path = osp.join(logger.get_logger_dir(),
                                       'checkpoint.pth.tar')
            torch.save(
                {
                    'iteration': i_iter,
                    'optim_state_dict': optimizer.state_dict(),
                    'optim_D1_state_dict': optimizer_D1.state_dict(),
                    'optim_D2_state_dict': optimizer_D2.state_dict(),
                    'model_state_dict': model.state_dict(),
                    'model_D1_state_dict': model_D1.state_dict(),
                    'model_D2_state_dict': model_D2.state_dict(),
                    'best_mean_iu': cur_miou16,
                }, checkpoint_path)
            if is_best:
                shutil.copy(checkpoint_path,
                            osp.join(logger.get_logger_dir(),
                                     'model_best.pth.tar'))

        if i_iter >= args.num_steps_stop - 1:
            break
def main():
    """Create the model and start the training.

    Two-output adversarial domain adaptation from ``SynthiaDataSet``
    (labelled source) to ``CamvidDataSet`` (unlabelled target).  Both
    network outputs get a segmentation loss on source batches (the first
    weighted by ``args.lambda_seg``) and an adversarial loss on target
    batches; ``model_D1`` / ``model_D2`` are trained to separate source
    from target predictions of the matching output.  Snapshots are written
    to ``args.snapshot_dir`` every ``args.save_pred_every`` iterations and
    once more when ``args.num_steps_stop`` is reached.

    All configuration is read from the module-level ``args``.

    Raises:
        ValueError: if ``args.model`` is not 'DeepLab' (previously this
            surfaced later as a NameError on the unbound ``model``).
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # Keys look like "Scale.layer5.conv2d_list.3.weight"; the leading
            # component is dropped.  layer5 weights are skipped only when
            # num_classes == 19.
            key_parts = key.split('.')
            if args.num_classes != 19 or key_parts[1] != 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)
    else:
        raise ValueError('Unsupported model: {}'.format(args.model))

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    def _label_like(d_out, label):
        """Return a tensor on args.gpu shaped like ``d_out``, filled with ``label``."""
        return Variable(
            torch.FloatTensor(d_out.data.size()).fill_(label)).cuda(args.gpu)

    def _save_snapshot(tag):
        """Save G, D1 and D2 weights as GTA5_<tag>.pth / _D1.pth / _D2.pth."""
        for net, suffix in ((model, '.pth'),
                            (model_D1, '_D1.pth'),
                            (model_D2, '_D2.pth')):
            torch.save(net.state_dict(),
                       osp.join(args.snapshot_dir,
                                'GTA5_' + str(tag) + suffix))

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Dataset length requested from both loaders (derived from num_steps).
    max_iters = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(
        SynthiaDataSet(args.data_dir, args.data_list, max_iters=max_iters,
                       crop_size=input_size, scale=args.random_scale,
                       mirror=args.random_mirror, mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
        CamvidDataSet(args.data_dir_target, args.data_list_target,
                      max_iters=max_iters, crop_size=input_size_target,
                      scale=False, mirror=args.random_mirror,
                      mean=IMG_MEAN, set=args.set),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # model.optim_parameters(args) handles the per-layer lr setting.
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(
        size=(input_size_target[1], input_size_target[0]),
        mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # ---- train G: don't accumulate grads in D ----
            for param in model_D1.parameters():
                param.requires_grad = False
            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization over the accumulated sub-iterations
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # train with target: push D to label target predictions "source"
            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            D_out1 = model_D1(F.softmax(pred_target1))
            D_out2 = model_D2(F.softmax(pred_target2))

            loss_adv_target1 = bce_loss(D_out1,
                                        _label_like(D_out1, source_label))
            loss_adv_target2 = bce_loss(D_out2,
                                        _label_like(D_out2, source_label))

            loss = (args.lambda_adv_target1 * loss_adv_target1
                    + args.lambda_adv_target2 * loss_adv_target2)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
            ) / args.iter_size

            # ---- train D: bring back requires_grad ----
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detached so G gets no gradient)
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1))
            D_out2 = model_D2(F.softmax(pred2))

            loss_D1 = bce_loss(D_out1, _label_like(D_out1, source_label))
            loss_D2 = bce_loss(D_out2, _label_like(D_out2, source_label))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1))
            D_out2 = model_D2(F.softmax(pred_target2))

            loss_D1 = bce_loss(D_out1, _label_like(D_out1, target_label))
            loss_D2 = bce_loss(D_out2, _label_like(D_out2, target_label))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            # Parenthesized single-argument print behaves the same on
            # Python 2 and 3 (the bare print statement was Python-2-only).
            print('save model ...')
            # NOTE(review): the final snapshot is tagged with num_steps,
            # not num_steps_stop.
            _save_snapshot(args.num_steps)
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            _save_snapshot(i_iter)
def main():
    """Train an FCN-8s network on the source data and keep the best checkpoint.

    Source-only training: each epoch runs supervised cross-entropy training
    over the source loader, then measures mean IoU on the target validation
    set with ``test_miou``.  Whenever that value improves, the weights of the
    wrapped network are saved to ``<args.snapshot_dir>/final_fcn.pth``.

    Configuration comes from ``get_arguments()``.
    """
    import os

    args = get_arguments()
    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    # Create network
    student_net = FCN8s(args.num_classes, args.model_path_prefix)
    student_net = torch.nn.DataParallel(student_net)
    student_net = student_net.cuda()

    # NOTE(review): presumably Caffe-style BGR channel means on a 0-255
    # scale (FlipChannels runs first and pixels are multiplied by 255 before
    # Normalize) -- confirm against the FCN8s weights.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    # Validation: only the image is resized (FreeScale in the input
    # transform); the label goes through MaskToTensor unchanged.
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()

    # A source path containing 'syn' selects SYNTHIA; anything else is read
    # as the Cityscapes16 LMDB dataset.
    if 'syn' in args.data_dir:
        src_dataset_cls = SynthiaDataSet
    else:
        src_dataset_cls = Cityscapes16DataSetLMDB
    src_dataset = src_dataset_cls(
        args.data_dir,
        args.data_list,
        joint_transform=train_joint_transform,
        transform=input_transform,
        target_transform=target_transform,
    )
    src_loader = data.DataLoader(
        src_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True
    )

    tgt_val_dataset = Cityscapes16DataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    tgt_val_loader = data.DataLoader(
        tgt_val_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    optimizer = optim.SGD(
        student_net.parameters(),
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # Passed to test_miou, presumably to bring predictions to the target
    # size.  align_corners=False is the default, made explicit here.
    upsample = nn.Upsample(size=(h_target, w_target), mode='bilinear',
                           align_corners=False)

    # Sum over pixels (label 255 ignored); divided by the batch size below.
    # reduction='sum' is the non-deprecated spelling of size_average=False.
    src_criterion = torch.nn.CrossEntropyLoss(ignore_index=255,
                                              reduction='sum')

    # torch.save does not create missing directories.
    if args.snapshot_dir:
        os.makedirs(args.snapshot_dir, exist_ok=True)

    num_batches = len(src_loader)
    highest = 0
    for epoch in range(args.num_epoch):
        cls_loss_rec = AverageMeter()
        miu_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, src_data in enumerate(src_loader):
            student_net.train()
            optimizer.zero_grad()

            # train with source
            src_images, src_label, _ = src_data
            src_images = src_images.cuda()
            src_label = src_label.cuda()
            data_time_rec.update(time.time() - tem_time)

            src_output = student_net(src_images)

            # Segmentation loss, averaged per image.
            cls_loss_value = src_criterion(src_output, src_label)
            cls_loss_value /= src_images.shape[0]
            cls_loss_value.backward()
            optimizer.step()

            # Training mIoU of this batch (for logging only).
            _, predict_labels = torch.max(src_output, 1)
            lbl_pred = predict_labels.detach().cpu().numpy()
            lbl_true = src_label.detach().cpu().numpy()
            _, _, _, mean_iu, _ = _evaluate(lbl_pred, lbl_true,
                                            args.num_classes)

            cls_loss_rec.update(cls_loss_value.detach_().item())
            miu_rec.update(mean_iu)
            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f} '
                    f'Data: {data_time_rec.avg:.2f} '
                    f'Mean iu: {miu_rec.avg*100:.1f} '
                    f'CLS: {cls_loss_rec.avg:.2f}'
                )

        # NOTE(review): the original source had no usable indentation; this
        # evaluation is placed once per epoch -- confirm it was not meant to
        # run inside the print_freq branch.
        miu = test_miou(student_net, tgt_val_loader, upsample,
                        './dataset/info_synthia.json',
                        num_classes=args.num_classes)
        if miu > highest:
            torch.save(student_net.module.state_dict(),
                       osp.join(args.snapshot_dir, 'final_fcn.pth'))
            print(f'save highest with {miu:.2%}')
            highest = miu
def main(args):
    """Build the SYNTHIA -> Cityscapes16 loaders and run ``StyleTrans`` training.

    Args:
        args: parsed options.  Reads ``tensorboard_log_dir``, ``input_size``
            ("w,h"), ``data_dir``/``data_list`` (source),
            ``data_dir_target``/``data_list_target`` (target),
            ``data_dir_val``/``data_list_val`` (validation), ``batch_size``
            and ``num_workers``; the whole namespace is also passed to
            ``StyleTrans``.
    """
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    try:
        w, h = map(int, args.input_size.split(','))

        # Shared by the source and target training sets: resize to (h, w)
        # plus random horizontal flip.
        joint_transform = joint_transforms.Compose([
            joint_transforms.FreeScale((h, w)),
            joint_transforms.RandomHorizontallyFlip(),
        ])

        normalize = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # NOTE(review): source images are only converted with ToTensor (no
        # Normalize), unlike target/val -- confirm this asymmetry is what
        # StyleTrans expects.
        src_input_transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
        ])
        tgt_input_transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*normalize),
        ])
        val_input_transform = standard_transforms.Compose([
            extended_transforms.FreeScale((h, w)),
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*normalize),
        ])
        target_transform = extended_transforms.MaskToTensor()

        src_dataset = SynthiaDataSet(
            args.data_dir, args.data_list,
            joint_transform=joint_transform,
            transform=src_input_transform,
            target_transform=target_transform,
        )
        src_loader = data.DataLoader(
            src_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.num_workers, pin_memory=True, drop_last=True)

        tgt_dataset = Cityscapes16DataSetLMDB(
            args.data_dir_target, args.data_list_target,
            joint_transform=joint_transform,
            transform=tgt_input_transform,
            target_transform=target_transform,
        )
        tgt_loader = data.DataLoader(
            tgt_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.num_workers, pin_memory=True, drop_last=True)

        val_dataset = Cityscapes16DataSetLMDB(
            args.data_dir_val, args.data_list_val,
            transform=val_input_transform,
            target_transform=target_transform,
        )
        val_loader = data.DataLoader(
            val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.num_workers, pin_memory=True, drop_last=False)

        style_trans = StyleTrans(args,
                                 info_json='./dataset/info_synthia.json')
        style_trans.train(src_loader, tgt_loader, val_loader, writer)
    finally:
        # Flush and close the event file even if setup or training raises.
        writer.close()
def main():
    """Create the model and start the training.

    Drives ``AD_Trainer``.  Each iteration:
    - adjusts the learning rates;
    - runs ``args.iter_size`` generator updates on paired source
      (``SynthiaDataSet``) / target (``cityscapesDataSet``) batches;
    - runs discriminator updates when both adversarial weights are positive;
    - prints the losses, and optionally logs them to TensorBoard every 100
      iterations;
    - saves G/D1/D2 snapshots to ``args.snapshot_dir``.

    All configuration is read from the module-level ``args``.  Note that
    ``args.input_size``, ``args.crop_size``, ``args.input_size_target``,
    ``args.multi_gpu`` and (with TensorBoard) ``args.log_dir`` are
    overwritten in place.
    """
    # Input sizes are stored as (w, h); the crop size is stored as (h, w).
    w, h = map(int, args.input_size.split(','))
    args.input_size = (w, h)

    w, h = map(int, args.crop_size.split(','))
    args.crop_size = (h, w)

    w, h = map(int, args.input_size_target.split(','))
    args.input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    # Negative ids (e.g. "-1") are dropped.
    gpu_ids = [gid for gid in map(int, args.gpu_ids.split(',')) if gid >= 0]

    # multi_gpu must be set before AD_Trainer is built, since it receives args.
    args.multi_gpu = len(gpu_ids) > 1
    trainer = AD_Trainer(args)
    if args.multi_gpu:
        trainer.G = torch.nn.DataParallel(trainer.G, gpu_ids)
        trainer.D1 = torch.nn.DataParallel(trainer.D1, gpu_ids)
        trainer.D2 = torch.nn.DataParallel(trainer.D2, gpu_ids)

    print(trainer)

    def _save_snapshot(tag):
        """Save G, D1 and D2 weights as GTA5_<tag>.pth / _D1.pth / _D2.pth."""
        for net, suffix in ((trainer.G, '.pth'),
                            (trainer.D1, '_D1.pth'),
                            (trainer.D2, '_D2.pth')):
            torch.save(net.state_dict(),
                       osp.join(args.snapshot_dir,
                                'GTA5_' + str(tag) + suffix))

    # Dataset length requested from both loaders (derived from num_steps).
    max_iters = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(
        SynthiaDataSet(args.data_dir, args.data_list, max_iters=max_iters,
                       resize_size=args.input_size,
                       crop_size=args.crop_size,
                       scale=True, mirror=True, mean=IMG_MEAN,
                       autoaug=args.autoaug),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True, drop_last=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
        cityscapesDataSet(args.data_dir_target, args.data_list_target,
                          max_iters=max_iters,
                          resize_size=args.input_size_target,
                          crop_size=args.crop_size,
                          scale=False, mirror=args.random_mirror,
                          mean=IMG_MEAN, set=args.set,
                          autoaug=args.autoaug_target),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True, drop_last=True)
    targetloader_iter = enumerate(targetloader)

    # set up tensor board
    if args.tensorboard:
        args.log_dir += '/' + os.path.basename(args.snapshot_dir)
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        adjust_learning_rate(trainer.gen_opt, i_iter, args)
        adjust_learning_rate_D(trainer.dis1_opt, i_iter, args)
        adjust_learning_rate_D(trainer.dis2_opt, i_iter, args)

        for sub_i in range(args.iter_size):
            # train G with one source batch and one target batch
            _, batch = next(trainloader_iter)
            _, batch_t = next(targetloader_iter)

            images, labels, _, _ = batch
            images = images.cuda()
            labels = labels.long().cuda()
            images_t, labels_t, _, _ = batch_t
            images_t = images_t.cuda()
            labels_t = labels_t.long().cuda()

            with Timer("Elapsed time in update: %f"):
                (loss_seg1, loss_seg2, loss_adv_target1, loss_adv_target2,
                 loss_me, loss_kl, pred1, pred2, pred_target1, pred_target2,
                 val_loss) = trainer.gen_update(images, images_t, labels,
                                                labels_t, i_iter)
                loss_seg_value1 += loss_seg1.item() / args.iter_size
                loss_seg_value2 += loss_seg2.item() / args.iter_size
                loss_adv_target_value1 += loss_adv_target1 / args.iter_size
                loss_adv_target_value2 += loss_adv_target2 / args.iter_size
                loss_me_value = loss_me

                # D is only updated when both adversarial weights are on.
                if args.lambda_adv_target1 > 0 and args.lambda_adv_target2 > 0:
                    loss_D1, loss_D2 = trainer.dis_update(
                        pred1, pred2, pred_target1, pred_target2)
                    loss_D_value1 += loss_D1.item()
                    loss_D_value2 += loss_D2.item()
                else:
                    loss_D_value1 = 0
                    loss_D_value2 = 0

            del pred1, pred2, pred_target1, pred_target2

        if args.tensorboard:
            scalar_info = {
                'loss_seg1': loss_seg_value1,
                'loss_seg2': loss_seg_value2,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_me_target': loss_me_value,
                'loss_kl_target': loss_kl,
                'loss_D1': loss_D_value1,
                'loss_D2': loss_D_value2,
                'val_loss': val_loss,
            }

            if i_iter % 100 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            '\033[1m iter = %8d/%8d \033[0m loss_seg1 = %.3f loss_seg2 = %.3f loss_me = %.3f loss_kl = %.3f loss_adv1 = %.3f, loss_adv2 = %.3f loss_D1 = %.3f loss_D2 = %.3f, val_loss=%.3f'
            % (i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
               loss_me_value, loss_kl, loss_adv_target_value1,
               loss_adv_target_value2, loss_D_value1, loss_D_value2,
               val_loss))

        # clear loss
        del loss_seg1, loss_seg2, loss_adv_target1, loss_adv_target2, loss_me, loss_kl, val_loss

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            _save_snapshot(args.num_steps_stop)
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            _save_snapshot(i_iter)

    if args.tensorboard:
        writer.close()