def CreateTrgDataLoader(args):
    """Build the target-domain data loader selected by ``args``.

    Three dataset configurations are handled, checked in this order:

    * ``args.source == 'triangle'``: a ``triangleDatasetTgt`` wrapped in a
      shuffled loader, returned immediately regardless of ``args.set``.
    * ``args.data_label_folder_target`` is not ``None``: a
      ``cityscapesDataSetLabel`` that is handed the folder through its
      ``label_folder`` keyword.
    * otherwise: a ``cityscapesDataSet``; ``max_iters`` is passed only when
      ``args.set == 'train'``.

    For the two Cityscapes cases the loader shuffles and uses
    ``args.batch_size`` / ``args.num_workers`` when ``args.set == 'train'``;
    for any other split it yields one sample at a time in dataset order.

    Args:
        args: Command-line namespace. Reads ``source``, ``data_dir_target``,
            ``data_list_target``, ``data_label_folder_target``, ``set``,
            ``num_steps``, ``batch_size`` and ``num_workers``.

    Returns:
        The ``data.DataLoader`` wrapping the selected dataset.
    """
    if args.source == 'triangle':
        # for simple triangle dataset
        triangle_set = triangleDatasetTgt(
            args.data_dir_target,
            args.data_list_target,
            max_iters=args.num_steps * args.batch_size,
            crop_size=image_sizes['triangle'],
            mean=IMG_MEAN,
        )
        return data.DataLoader(
            triangle_set,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
        )

    if args.data_label_folder_target is None:
        # Unlabelled Cityscapes: the iteration budget is only supplied for
        # the training split.
        dataset_options = {}
        if args.set == 'train':
            dataset_options['max_iters'] = args.num_steps * args.batch_size
        dataset_options.update(
            crop_size=image_sizes['cityscapes'],
            mean=IMG_MEAN,
            set=args.set,
        )
        cityscapes_set = cityscapesDataSet(
            args.data_dir_target,
            args.data_list_target,
            **dataset_options
        )
    else:
        cityscapes_set = cityscapesDataSetLabel(
            args.data_dir_target,
            args.data_list_target,
            max_iters=args.num_steps * args.batch_size,
            crop_size=image_sizes['cityscapes'],
            mean=IMG_MEAN,
            set=args.set,
            label_folder=args.data_label_folder_target,
        )

    if args.set == 'train':
        return data.DataLoader(
            cityscapes_set,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
        )
    return data.DataLoader(
        cityscapes_set,
        batch_size=1,
        shuffle=False,
        pin_memory=True,
    )
def CreateTrgDataLoader(args):
    """Build the Cityscapes target loader for the split named by ``args.set``.

    For ``'train'`` and ``'trainval'`` a ``cityscapesDataSetLabel`` is built
    with the ``image_sizes['cityscapes']`` crop and an iteration budget of
    ``args.num_steps * args.batch_size``, and served shuffled in batches of
    ``args.batch_size`` using ``args.num_workers`` workers.

    For every other split a ``cityscapesDataSet`` is built with the
    ``cs_size_test['cityscapes']`` crop and served one sample at a time in
    dataset order.

    Args:
        args: Command-line namespace. Reads ``set``, ``data_dir_target``,
            ``data_list_target``, ``num_steps``, ``batch_size`` and
            ``num_workers``.

    Returns:
        The ``data.DataLoader`` wrapping the selected dataset.
    """
    is_training_split = args.set in ('train', 'trainval')

    if not is_training_split:
        eval_set = cityscapesDataSet(
            args.data_dir_target,
            args.data_list_target,
            crop_size=cs_size_test['cityscapes'],
            mean=IMG_MEAN,
            set=args.set,
        )
        return data.DataLoader(
            eval_set,
            batch_size=1,
            shuffle=False,
            pin_memory=True,
        )

    train_set = cityscapesDataSetLabel(
        args.data_dir_target,
        args.data_list_target,
        crop_size=image_sizes['cityscapes'],
        mean=IMG_MEAN,
        max_iters=args.num_steps * args.batch_size,
        set=args.set,
    )
    return data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
def CreateTrgDataSSLLoader(args):
    """Return a sequential, one-sample-per-batch loader over the target set.

    The dataset is a ``cityscapesDataSet`` for the split ``args.set`` with the
    ``image_sizes['cityscapes']`` crop. No shuffling is applied and no worker
    count is passed, so the loader's own default applies.

    NOTE(review): the name suggests this feeds a self-supervised /
    pseudo-labelling stage; nothing in this function depends on that.

    Args:
        args: Command-line namespace. Reads ``data_dir_target``,
            ``data_list_target`` and ``set``.

    Returns:
        The ``data.DataLoader`` wrapping the dataset.
    """
    return data.DataLoader(
        cityscapesDataSet(
            args.data_dir_target,
            args.data_list_target,
            crop_size=image_sizes['cityscapes'],
            mean=IMG_MEAN,
            set=args.set,
        ),
        batch_size=1,
        shuffle=False,
        pin_memory=True,
    )
def CreateTrgDataSSLLoader(args):
    """Create the target data loader used for SSL training.

    Wraps a ``cityscapesDataSet`` (split ``args.set``, crop
    ``image_sizes['cityscapes']``, mean ``IMG_MEAN``) in a loader that yields
    one sample per batch, in dataset order, with pinned memory.

    Args:
        args: Command-line arguments. Reads ``data_dir_target``,
            ``data_list_target`` and ``set``.

    Returns:
        torch.utils.data.DataLoader
    """
    dataset_options = {
        'crop_size': image_sizes['cityscapes'],
        'mean': IMG_MEAN,
        'set': args.set,
    }
    ssl_dataset = cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        **dataset_options
    )

    # Fixed order and batch size of one.
    loader_options = {'batch_size': 1, 'shuffle': False, 'pin_memory': True}
    return data.DataLoader(ssl_dataset, **loader_options)
def _set_requires_grad(modules, flag):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``modules``."""
    for module in modules:
        for param in module.parameters():
            param.requires_grad = flag


def _domain_target(d_out, domain_label):
    """Return a tensor shaped like ``d_out`` filled with ``domain_label``.

    The tensor is created directly on ``d_out``'s device, so no host-side
    allocation and copy is needed.
    """
    return torch.full_like(d_out, domain_label)


def _save_checkpoint(model, model_D1, model_D2, snapshot_dir, tag):
    """Save the segmentation model and both discriminators under ``tag``."""
    prefix = 'GTA5_' + str(tag)
    torch.save(model.state_dict(), osp.join(snapshot_dir, prefix + '.pth'))
    torch.save(model_D1.state_dict(), osp.join(snapshot_dir, prefix + '_D1.pth'))
    torch.save(model_D2.state_dict(), osp.join(snapshot_dir, prefix + '_D2.pth'))


def main():
    """Create the model and start the training.

    Reads its whole configuration from the module-level ``args`` namespace.
    Trains a two-output segmentation network on the source loader with a
    cross-entropy loss while two discriminators (one per output) are trained
    to tell source predictions from target predictions; the segmentation
    network additionally receives an adversarial loss on target predictions.
    Checkpoints are written to ``args.snapshot_dir`` every
    ``args.save_pred_every`` iterations and once more when
    ``args.num_steps_stop`` is reached, at which point training ends.

    Raises:
        ValueError: If ``args.model`` is not ``'DeepLab'`` or ``args.gan`` is
            neither ``'Vanilla'`` nor ``'LS'``.
    """
    # Device setup.
    device = torch.device("cuda" if not args.cpu else "cpu")

    # Parse the "width,height" strings. (Original note: both source and
    # target are resized to 1280 * 720 -- that depends on the command-line
    # defaults, which are not visible here.)
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create the model. Fail early on an unsupported choice instead of
    # hitting a NameError on ``model`` further down.
    if args.model != 'DeepLab':
        raise ValueError('unsupported model: {!r}'.format(args.model))

    model = DeeplabMulti(num_classes=args.num_classes)
    if args.restore_from[:4] == 'http':
        # Download the pretrained weights.
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        # A .pth file was given directly. Load onto the CPU so checkpoints
        # saved from a GPU can also be restored on CPU-only runs; the
        # values are copied into the model's own tensors afterwards.
        # NOTE(security): torch.load unpickles the file -- only restore
        # checkpoints from trusted sources.
        saved_state_dict = torch.load(args.restore_from, map_location='cpu')

    new_params = model.state_dict().copy()
    for key in saved_state_dict:
        # Keys look like "Scale.layer5.conv2d_list.3.weight"; the first
        # component is dropped to match this model's parameter names.
        key_parts = key.split('.')
        # With 19 classes the checkpoint's "layer5" entries are skipped and
        # the freshly initialised ones are kept.
        if args.num_classes != 19 or key_parts[1] != 'layer5':
            new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
    model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # Create the discriminators, one per segmentation output.
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D1.train()
    model_D2.train()
    discriminators = (model_D1, model_D2)

    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(args.snapshot_dir, exist_ok=True)

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list,
                    max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror,
                    mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    trainloader_iter = iter(trainloader)

    targetloader = data.DataLoader(
        cityscapesDataSet(args.data_dir_target, args.data_list_target,
                          max_iters=args.num_steps * args.iter_size * args.batch_size,
                          crop_size=input_size_target,
                          scale=False, mirror=args.random_mirror,
                          mean=IMG_MEAN, set=args.set),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    targetloader_iter = iter(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        # Previously this fell through and crashed with a NameError on the
        # first use of bce_loss inside the training loop.
        raise ValueError('unsupported gan loss: {!r}'.format(args.gan))
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # Upsample takes (height, width); the parsed sizes are (width, height).
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Gradients are accumulated over iter_size sub-steps before the
        # optimizers step once.
        for _ in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            _set_requires_grad(discriminators, False)

            # train with source
            images, labels, _, _, domainess = next(trainloader_iter)

            # Weight for the adversarial loss, derived from the source
            # batch's domainess. It must live on the same device as the loss
            # it multiplies; the loader hands it over on the CPU.
            # NOTE(review): adw is presumably one value per source sample,
            # so with batch_size > 1 the weighted loss would no longer be a
            # single-element tensor and backward()/item() would fail --
            # confirm the intended batch size.
            adw = torch.sqrt(1 - domainess).float().to(device)
            adw.requires_grad = False

            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target
            images, _, _ = next(targetloader_iter)
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # The predictions are 4-D (bilinear Upsample requires it), so
            # dim=1 is the axis the implicit-dim softmax used; passing it
            # explicitly removes the deprecation warning.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # Target predictions are scored against the *source* label so
            # the segmentation network is pushed to fool the discriminators.
            loss_adv_target1 = adw * bce_loss(D_out1, _domain_target(D_out1, source_label))
            loss_adv_target2 = adw * bce_loss(D_out2, _domain_target(D_out2, source_label))

            loss = (args.lambda_adv_target1 * loss_adv_target1
                    + args.lambda_adv_target2 * loss_adv_target2)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # train D

            # bring back requires_grad
            _set_requires_grad(discriminators, True)

            # train with source; detach so no gradient reaches the
            # segmentation network from the discriminator update
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(D_out1, _domain_target(D_out1, source_label))
            loss_D2 = bce_loss(D_out2, _domain_target(D_out2, source_label))

            # Halved because each discriminator sees two batches (source and
            # target) per sub-step.
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(D_out1, _domain_target(D_out1, target_label))
            loss_D2 = bce_loss(D_out2, _domain_target(D_out2, target_label))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        # Scalars are logged every 10th iteration only.
        if args.tensorboard and i_iter % 10 == 0:
            scalar_info = {
                'loss_seg1': loss_seg_value1,
                'loss_seg2': loss_seg_value2,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D1': loss_D_value1,
                'loss_D2': loss_D_value2,
            }
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'.format(
                i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                loss_adv_target_value1, loss_adv_target_value2,
                loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            _save_checkpoint(model, model_D1, model_D2,
                             args.snapshot_dir, args.num_steps_stop)
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            _save_checkpoint(model, model_D1, model_D2,
                             args.snapshot_dir, i_iter)

    if args.tensorboard:
        writer.close()