def train_minent(model, trainloader, targetloader, cfg):
    '''UDA training with minEnt (direct entropy minimization).

    Each iteration draws one labeled source batch (supervised segmentation
    loss) and one unlabeled target batch (entropy of the softmax output is
    minimized), then takes a single SGD step on the accumulated gradients.

    :param model: segmentation network exposing ``optim_parameters``; it must
        return ``(aux_logits, main_logits)`` from a forward pass.
    :param trainloader: source-domain loader yielding ``(images, labels, _, _)``.
    :param targetloader: target-domain loader yielding ``(images, _, _, _)``.
    :param cfg: experiment configuration (TRAIN.* hyper-parameters, GPU_ID, ...).
    '''
    # Create the model and start the training.
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES
    # Tensorboard logging is enabled only when the log dir already exists.
    viz_tensorboard = os.path.exists(cfg.TRAIN.TENSORBOARD_LOGDIR)
    if viz_tensorboard:
        writer = SummaryWriter(log_dir=cfg.TRAIN.TENSORBOARD_LOGDIR)
    # SEGMENTATION NETWORK
    model.train()
    model.to(device)
    cudnn.benchmark = True
    cudnn.enabled = True
    # OPTIMIZER for the segmentation network.
    optimizer = optim.SGD(model.optim_parameters(cfg.TRAIN.LEARNING_RATE),
                          lr=cfg.TRAIN.LEARNING_RATE,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    # Interpolate output segmaps to full input resolution; sizes are (W, H),
    # hence the [1], [0] swap to nn.Upsample's (H, W) convention.
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear', align_corners=True)
    trainloader_iter = enumerate(trainloader)
    targetloader_iter = enumerate(targetloader)
    for i_iter in tqdm(range(cfg.TRAIN.EARLY_STOP)):
        # Reset optimizer and adapt LR if needed.
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter, cfg)
        # --- UDA Training: supervised loss on a source batch.
        _, batch = trainloader_iter.__next__()
        images_source, labels, _, _ = batch
        pred_src_aux, pred_src_main = model(images_source.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = interp(pred_src_aux)
            loss_seg_src_aux = loss_calc(pred_src_aux, labels, device)
        else:
            loss_seg_src_aux = 0
        pred_src_main = interp(pred_src_main)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        loss = (cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main
                + cfg.TRAIN.LAMBDA_SEG_AUX * loss_seg_src_aux)
        loss.backward()
        # --- Adaptation: minimize prediction entropy on a target batch.
        _, batch = targetloader_iter.__next__()
        images, _, _, _ = batch
        pred_trg_aux, pred_trg_main = model(images.cuda(device))
        pred_trg_aux = interp_target(pred_trg_aux)
        pred_trg_main = interp_target(pred_trg_main)
        # dim=1 made explicit (class channel); the implicit-dim form of
        # F.softmax is deprecated and merely defaulted to the same axis.
        pred_prob_trg_aux = F.softmax(pred_trg_aux, dim=1)
        pred_prob_trg_main = F.softmax(pred_trg_main, dim=1)
        loss_target_entp_aux = entropy_loss(pred_prob_trg_aux)
        loss_target_entp_main = entropy_loss(pred_prob_trg_main)
        loss = (cfg.TRAIN.LAMBDA_ENT_AUX * loss_target_entp_aux
                + cfg.TRAIN.LAMBDA_ENT_MAIN * loss_target_entp_main)
        loss.backward()
        # Single step on gradients accumulated from both backward passes.
        optimizer.step()
        current_losses = {'loss_seg_src_aux': loss_seg_src_aux,
                          'loss_seg_src_main': loss_seg_src_main,
                          'loss_ent_aux': loss_target_entp_aux,
                          'loss_ent_main': loss_target_entp_main}
        print_losses(current_losses, i_iter)
        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print('taking snapshot ...')
            print('exp =', cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(),
                       osp.join(cfg.TRAIN.SNAPSHOT_DIR, f'model_{i_iter}.pth'))
            if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
                break
        sys.stdout.flush()
        # Visualize with tensorboard.
        if viz_tensorboard:
            log_losses_tensorboard(writer, current_losses, i_iter)
            if i_iter % cfg.TRAIN.TENSORBOARD_VIZRATE == cfg.TRAIN.TENSORBOARD_VIZRATE - 1:
                draw_in_tensorboard(writer, images, i_iter, pred_trg_main,
                                    num_classes, 'T')
                draw_in_tensorboard(writer, images_source, i_iter, pred_src_main,
                                    num_classes, 'S')
def train_advent(model, trainloader, targetloader, cfg):
    '''UDA training with AdvEnt (adversarial entropy minimization).

    Per iteration: (1) supervised segmentation loss on a source batch;
    (2) adversarial loss pushing the entropy maps of target predictions to
    be classified as "source" by frozen discriminators; (3) discriminator
    updates on detached source/target entropy maps; (4) optimizer steps.

    :param model: segmentation network returning ``(aux_logits, main_logits)``.
    :param trainloader: source loader yielding ``(images, labels, _, _)``.
    :param targetloader: target loader yielding ``(images, _, _, _)``.
    :param cfg: experiment configuration.
    '''
    # Create the model and start the training.
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES
    viz_tensorboard = os.path.exists(cfg.TRAIN.TENSORBOARD_LOGDIR)
    if viz_tensorboard:
        writer = SummaryWriter(log_dir=cfg.TRAIN.TENSORBOARD_LOGDIR)
    # SEGMENTATION NETWORK
    model.train()
    model.to(device)
    cudnn.benchmark = True
    cudnn.enabled = True
    # DISCRIMINATOR NETWORKS: one per prediction level (aux / main).
    d_aux = get_fc_discriminator(num_classes=num_classes)
    d_aux.train()
    d_aux.to(device)
    d_main = get_fc_discriminator(num_classes=num_classes)
    d_main.train()
    d_main.to(device)
    # OPTIMIZERS: SGD for the segmenter, Adam for each discriminator.
    optimizer = optim.SGD(model.optim_parameters(cfg.TRAIN.LEARNING_RATE),
                          lr=cfg.TRAIN.LEARNING_RATE,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    optimizer_d_aux = optim.Adam(d_aux.parameters(), lr=cfg.TRAIN.LEARNING_RATE_D,
                                 betas=(0.9, 0.99))
    optimizer_d_main = optim.Adam(d_main.parameters(), lr=cfg.TRAIN.LEARNING_RATE_D,
                                  betas=(0.9, 0.99))
    # Interpolate output segmaps to full resolution; sizes are (W, H).
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear', align_corners=True)
    # Domain labels for adversarial training.
    source_label = 0
    target_label = 1
    trainloader_iter = enumerate(trainloader)
    targetloader_iter = enumerate(targetloader)
    for i_iter in tqdm(range(cfg.TRAIN.EARLY_STOP)):
        # Reset optimizers and adapt LR if needed.
        optimizer.zero_grad()
        optimizer_d_aux.zero_grad()
        optimizer_d_main.zero_grad()
        adjust_learning_rate(optimizer, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_aux, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg)
        # --- UDA Training: only train segnet; freeze discriminator grads.
        for param in d_aux.parameters():
            param.requires_grad = False
        for param in d_main.parameters():
            param.requires_grad = False
        # Supervised loss on a source batch.
        _, batch = trainloader_iter.__next__()
        images_source, labels, _, _ = batch
        pred_src_aux, pred_src_main = model(images_source.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = interp(pred_src_aux)
            loss_seg_src_aux = loss_calc(pred_src_aux, labels, device)
        else:
            loss_seg_src_aux = 0
        pred_src_main = interp(pred_src_main)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        loss = (cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main
                + cfg.TRAIN.LAMBDA_SEG_AUX * loss_seg_src_aux)
        loss.backward()
        # Adversarial training to fool the discriminators: target entropy
        # maps are labeled as "source". dim=1 made explicit everywhere
        # (class channel); the implicit-dim F.softmax form is deprecated.
        _, batch = targetloader_iter.__next__()
        images, _, _, _ = batch
        pred_trg_aux, pred_trg_main = model(images.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = interp_target(pred_trg_aux)
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux, dim=1)))
            loss_adv_trg_aux = bce_loss(d_out_aux, source_label)
        else:
            loss_adv_trg_aux = 0
        pred_trg_main = interp_target(pred_trg_main)
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main, dim=1)))
        loss_adv_trg_main = bce_loss(d_out_main, source_label)
        loss = (cfg.TRAIN.LAMBDA_ADV_MAIN * loss_adv_trg_main
                + cfg.TRAIN.LAMBDA_ADV_AUX * loss_adv_trg_aux)
        loss.backward()
        # --- Train discriminator networks (segmenter outputs detached).
        for param in d_aux.parameters():
            param.requires_grad = True
        for param in d_main.parameters():
            param.requires_grad = True
        # Train with source (label 0); /2 balances the two half-batches.
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = pred_src_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_src_aux, dim=1)))
            loss_d_aux = bce_loss(d_out_aux, source_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        pred_src_main = pred_src_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_src_main, dim=1)))
        loss_d_main = bce_loss(d_out_main, source_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()
        # Train with target (label 1).
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = pred_trg_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux, dim=1)))
            loss_d_aux = bce_loss(d_out_aux, target_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        else:
            loss_d_aux = 0
        pred_trg_main = pred_trg_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main, dim=1)))
        loss_d_main = bce_loss(d_out_main, target_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()
        optimizer.step()
        if cfg.TRAIN.MULTI_LEVEL:
            optimizer_d_aux.step()
        optimizer_d_main.step()
        current_losses = {'loss_seg_src_aux': loss_seg_src_aux,
                          'loss_seg_src_main': loss_seg_src_main,
                          'loss_adv_trg_aux': loss_adv_trg_aux,
                          'loss_adv_trg_main': loss_adv_trg_main,
                          'loss_d_aux': loss_d_aux,
                          'loss_d_main': loss_d_main}
        print_losses(current_losses, i_iter)
        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print('taking snapshot ...')
            print('exp =', cfg.TRAIN.SNAPSHOT_DIR)
            snapshot_dir = Path(cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(), snapshot_dir / f'model_{i_iter}.pth')
            torch.save(d_aux.state_dict(), snapshot_dir / f'model_{i_iter}_D_aux.pth')
            torch.save(d_main.state_dict(), snapshot_dir / f'model_{i_iter}_D_main.pth')
            if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
                break
        sys.stdout.flush()
        # Visualize with tensorboard.
        if viz_tensorboard:
            log_losses_tensorboard(writer, current_losses, i_iter)
            if i_iter % cfg.TRAIN.TENSORBOARD_VIZRATE == cfg.TRAIN.TENSORBOARD_VIZRATE - 1:
                draw_in_tensorboard(writer, images, i_iter, pred_trg_main,
                                    num_classes, 'T')
                draw_in_tensorboard(writer, images_source, i_iter, pred_src_main,
                                    num_classes, 'S')
def train_preview(model, source_loader, target_loader, cfg, comet_exp):
    # UDA TRAINING
    ''' UDA training with advent '''
    # AdvEnt-style training loop with comet.ml logging and a periodic
    # in-line preview evaluation (mIoU + overlay images) on the last
    # target batch. Loaders yield dicts: batch['data']['x'] = images,
    # batch['data']['m'] = masks/labels.
    # Create the model and start the training.
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES
    # SEGMENTATION NETWORK
    model.train()
    model.to(device)
    cudnn.benchmark = True
    cudnn.enabled = True
    # DISCRIMINATOR NETWORK
    # feature-level
    d_aux = get_fc_discriminator(num_classes=num_classes)
    d_aux.train()
    d_aux.to(device)
    # seg maps, i.e. output, level
    d_main = get_fc_discriminator(num_classes=num_classes)
    d_main.train()
    d_main.to(device)
    # OPTIMIZERS
    # segnet's optimizer
    optimizer = optim.SGD(model.optim_parameters(cfg.TRAIN.LEARNING_RATE),
                          lr=cfg.TRAIN.LEARNING_RATE,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    # discriminators' optimizers
    optimizer_d_aux = optim.Adam(d_aux.parameters(), lr=cfg.TRAIN.LEARNING_RATE_D,
                                 betas=(0.9, 0.99))
    optimizer_d_main = optim.Adam(d_main.parameters(), lr=cfg.TRAIN.LEARNING_RATE_D,
                                  betas=(0.9, 0.99))
    # interpolate output segmaps; input sizes are (W, H), Upsample wants (H, W)
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear', align_corners=True)
    # labels for adversarial training
    source_label = 0
    target_label = 1
    # rolling windows (last 100 iters) for wall-clock timing metrics
    times = deque([0], maxlen=100)
    model_times = deque([0], maxlen=100)
    source_loader_iter = enumerate(source_loader)
    target_loader_iter = enumerate(target_loader)
    cur_best_miou = -1
    cur_best_model = ''
    for i_iter in tqdm(range(cfg.TRAIN.EARLY_STOP + 1)):
        times.append(time())
        comet_exp.log_metric("i_iter", i_iter)
        comet_exp.log_metric("target_epoch", i_iter / len(target_loader))
        comet_exp.log_metric("source_epoch", i_iter / len(source_loader))
        # reset optimizers
        optimizer.zero_grad()
        optimizer_d_aux.zero_grad()
        optimizer_d_main.zero_grad()
        # adapt LR if needed
        adjust_learning_rate(optimizer, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_aux, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg)
        # UDA Training
        # only train segnet. Don't accumulate grads in disciminators
        for param in d_aux.parameters():
            param.requires_grad = False
        for param in d_main.parameters():
            param.requires_grad = False
        # train on source; loaders are re-wound on exhaustion (endless loop)
        try:
            _, batch_and_path = source_loader_iter.__next__()
        except StopIteration:
            source_loader_iter = enumerate(source_loader)
            _, batch_and_path = source_loader_iter.__next__()
        images_source, labels = batch_and_path['data']['x'], batch_and_path[
            'data']['m']
        pred_src_aux, pred_src_main = model(images_source.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = interp(pred_src_aux)
            loss_seg_src_aux = loss_calc(pred_src_aux, labels, device)
        else:
            loss_seg_src_aux = 0
        pred_src_main = interp(pred_src_main)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        loss = (cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main
                + cfg.TRAIN.LAMBDA_SEG_AUX * loss_seg_src_aux)
        loss.backward()
        # adversarial training to fool the discriminator
        try:
            _, batch = target_loader_iter.__next__()
        except StopIteration:
            target_loader_iter = enumerate(target_loader)
            _, batch = target_loader_iter.__next__()
        images = batch['data']['x']
        pred_trg_aux, pred_trg_main = model(images.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = interp_target(pred_trg_aux)
            # NOTE(review): implicit-dim F.softmax is deprecated; it
            # defaults to the class channel here — confirm before changing.
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux)))
            loss_adv_trg_aux = bce_loss(d_out_aux, source_label)
        else:
            loss_adv_trg_aux = 0
        pred_trg_main = interp_target(pred_trg_main)
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main)))
        loss_adv_trg_main = bce_loss(d_out_main, source_label)
        loss = (cfg.TRAIN.LAMBDA_ADV_MAIN * loss_adv_trg_main
                + cfg.TRAIN.LAMBDA_ADV_AUX * loss_adv_trg_aux)
        loss = loss
        loss.backward()
        # Train discriminator networks
        # enable training mode on discriminator networks
        for param in d_aux.parameters():
            param.requires_grad = True
        for param in d_main.parameters():
            param.requires_grad = True
        # train with source (label 0); /2 balances source/target halves
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = pred_src_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_src_aux)))
            loss_d_aux = bce_loss(d_out_aux, source_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        pred_src_main = pred_src_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_src_main)))
        loss_d_main = bce_loss(d_out_main, source_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()
        # train with target (label 1)
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = pred_trg_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux)))
            loss_d_aux = bce_loss(d_out_aux, target_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        else:
            loss_d_aux = 0
        pred_trg_main = pred_trg_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main)))
        loss_d_main = bce_loss(d_out_main, target_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()
        optimizer.step()
        if cfg.TRAIN.MULTI_LEVEL:
            optimizer_d_aux.step()
        optimizer_d_main.step()
        # log mean per-iteration model time over the rolling window
        model_times.append(time() - times[-1])
        mod_times = np.mean(model_times)
        comet_exp.log_metric("model_time", mod_times)
        current_losses = {
            'loss_seg_src_aux': loss_seg_src_aux,
            'loss_seg_src_main': loss_seg_src_main,
            'loss_adv_trg_aux': loss_adv_trg_aux,
            'loss_adv_trg_main': loss_adv_trg_main,
            'loss_d_aux': loss_d_aux,
            'loss_d_main': loss_d_main
        }
        print_losses(current_losses, i_iter)
        current_losses_numDict = tesnorDict2numDict(current_losses)
        comet_exp.log_metrics(current_losses_numDict)
        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print('taking snapshot ...')
            print('exp =', cfg.TRAIN.SNAPSHOT_DIR)
            snapshot_dir = Path(cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(), snapshot_dir / f'model_{i_iter}.pth')
            torch.save(d_aux.state_dict(), snapshot_dir / f'model_{i_iter}_D_aux.pth')
            torch.save(d_main.state_dict(), snapshot_dir / f'model_{i_iter}_D_main.pth')
            if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
                break
        # Periodic preview evaluation on the current (last) target batch.
        # NOTE(review): `and` binds tighter than `or`, so this fires when
        # (SAVE_IMAGE_PRED divides i_iter and i_iter != 0) OR at EARLY_STOP.
        if i_iter % cfg.TRAIN.SAVE_IMAGE_PRED == 0 and i_iter != 0 or i_iter == cfg.TRAIN.EARLY_STOP:
            print("Inferring test images in iteration {}...".format(i_iter))
            hist = np.zeros((cfg.NUM_CLASSES, cfg.NUM_CLASSES))
            # single-image evaluation: first sample of the last target batch
            image, label = batch['data']['x'][0], batch['data']['m'][0]
            image = image[None, :, :, :]
            # NOTE(review): rebinds `interp` to the label's size; the
            # source-resolution Upsample above is lost from here on.
            interp = nn.Upsample(size=(label.shape[1], label.shape[2]),
                                 mode='bilinear', align_corners=True)
            with torch.no_grad():
                pred_main = model(image.cuda(device))[1]
                output = interp(pred_main).cpu().data[0].numpy()
                output = output.transpose(1, 2, 0)
                output = np.argmax(output, axis=2)
            label0 = label.numpy()[0]
            # confusion matrix -> per-class IoU
            hist += fast_hist(label0.flatten(), output.flatten(), cfg.NUM_CLASSES)
            output = torch.tensor(output, dtype=torch.float32)
            output = output[None, :, :]
            output_RGB = output.repeat(3, 1, 1)
            if i_iter % 100 == 0:
                print('{:d} / {:d}: {:0.2f}'.format(
                    i_iter % len(target_loader), len(target_loader),
                    100 * np.nanmean(per_class_iu(hist))))
            inters_over_union_classes = per_class_iu(hist)
            computed_miou = round(
                np.nanmean(inters_over_union_classes) * 100, 2)
            if cur_best_miou < computed_miou:
                cur_best_miou = computed_miou
                cur_best_model = f'model_{i_iter}.pth'
            print('\tCurrent mIoU:', computed_miou)
            print('\tCurrent best model:', cur_best_model)
            print('\tCurrent best mIoU:', cur_best_miou)
            mious = {
                'Current mIoU': computed_miou,
                'Current best model': cur_best_model,
                'Current best mIoU': cur_best_miou
            }
            comet_exp.log_metrics(mious)
            image = image[0]  # change size from [1,x,y,z] to [x,y,z]
            save_images = []
            save_images.append(image)
            # Overlay mask: whiten (set to 1) the masked region of the image
            save_mask = (image - (image * label.repeat(3, 1, 1))
                         + label.repeat(3, 1, 1))
            save_fake_mask = (image - (image * output_RGB) + output_RGB)
            save_images.append(save_mask)
            save_images.append(save_fake_mask)
            save_images.append(label.repeat(3, 1, 1))
            save_images.append(output_RGB)
            write_images(save_images, i_iter, comet_exp=comet_exp,
                         store_im=cfg.TEST.store_images)
def train_advent(model, trainloader, targetloader, cfg):
    ''' UDA training with advent '''
    # Experimental AdvEnt variant: feature-level discriminator (d_aux) taken
    # from get_fe_discriminator over 1024-ch features, plus a confidence-
    # weighted BCE adversarial loss and pseudo-label self-training on the
    # target domain after warm-up (i_iter > 5000).
    # Create the model and start the training.
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES
    viz_tensorboard = os.path.exists(cfg.TRAIN.TENSORBOARD_LOGDIR)
    if viz_tensorboard:
        writer = SummaryWriter(log_dir=cfg.TRAIN.TENSORBOARD_LOGDIR)
    # SEGMENTATION NETWORK
    model.train()
    model.to(device)
    cudnn.benchmark = True
    cudnn.enabled = True
    # DISCRIMINATOR NETWORK
    # feature-level; operates on 1024-channel backbone features
    # d_aux = get_fc_discriminator(num_classes=num_classes)
    d_aux = get_fe_discriminator(num_classes=1024)
    # saved_state_dict_D1 = torch.load('.../model_125000_D_aux.pth')
    # d_aux.load_state_dict(saved_state_dict_D1)
    d_aux.train()
    d_aux.to(device)
    # seg maps, i.e. output, level
    d_main = get_fc_discriminator(num_classes=num_classes)
    # saved_state_dict_D2 = torch.load('.../model_125000_D_main.pth')
    # d_main.load_state_dict(saved_state_dict_D2)
    d_main.train()
    d_main.to(device)
    # OPTIMIZERS
    # segnet's optimizer
    optimizer = optim.SGD(model.optim_parameters(cfg.TRAIN.LEARNING_RATE),
                          lr=cfg.TRAIN.LEARNING_RATE,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    # discriminators' optimizers
    optimizer_d_aux = optim.Adam(d_aux.parameters(), lr=cfg.TRAIN.LEARNING_RATE_D,
                                 betas=(0.9, 0.99))
    optimizer_d_main = optim.Adam(d_main.parameters(), lr=cfg.TRAIN.LEARNING_RATE_D,
                                  betas=(0.9, 0.99))
    # interpolate output segmaps; sizes are (W, H)
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear', align_corners=True)
    # interp_aux = nn.Upsample(size=(128, 256), mode='bilinear', align_corners=True)  # H/4
    # interp_aux_source = nn.Upsample(size=(180, 320), mode='bilinear', align_corners=True)  # H/4
    weighted_bce_loss = WeightedBCEWithLogitsLoss()
    # 255 is the ignore index used for low-confidence pseudo-labels below
    criterion_seg = nn.CrossEntropyLoss(ignore_index=255)
    # labels for adversarial training
    source_label = 0
    target_label = 1
    Epsilon = 0.1        # weighted-BCE smoothing term
    Lambda_local = 1     # weighted-BCE local weighting factor
    trainloader_iter = enumerate(trainloader)
    targetloader_iter = enumerate(targetloader)
    for i_iter in tqdm(range(cfg.TRAIN.EARLY_STOP + 1)):
        # reset optimizers
        optimizer.zero_grad()
        optimizer_d_aux.zero_grad()
        optimizer_d_main.zero_grad()
        # adapt LR if needed
        adjust_learning_rate(optimizer, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_aux, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg)
        # linearly decays the adversarial loss weight over 100k iters
        damping = (1 - i_iter / 100000)
        ### UDA Training
        # only train segnet. Don't accumulate grads in disciminators
        for param in d_aux.parameters():
            param.requires_grad = False
        for param in d_main.parameters():
            param.requires_grad = False
        # train on source
        _, batch = trainloader_iter.__next__()
        images_source, labels, _, _ = batch
        pred_src_aux, pred_src_main = model(
            images_source.cuda(device)
        )  # H/8 multi-level outputs coming from both conv4 and conv5
        # pred_src_aux = interp_aux_source(pred_src_aux)  # H/4=1280/4
        # aux segmentation loss disabled in this variant
        loss_seg_src_aux = 0
        # if cfg.TRAIN.MULTI_LEVEL:
        #     pred_src_aux = interp(pred_src_aux)
        #     loss_seg_src_aux = loss_calc(pred_src_aux, labels, device)
        #     # pred_src_aux = F.softmax(pred_src_aux1)
        # else:
        #     loss_seg_src_aux = 0
        pred_src_main = interp(pred_src_main)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        loss = (cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main
                + cfg.TRAIN.LAMBDA_SEG_AUX * loss_seg_src_aux)
        loss.backward()
        # adversarial training ot fool the discriminator
        _, batch = targetloader_iter.__next__()
        images, _, _, _ = batch
        pred_trg_aux, pred_trg_main = model(
            images.cuda(device))  # H/8=120, H/8=129
        # pred_trg_aux = interp_aux(pred_trg_aux)  # H/4=256
        # keep pre-softmax logits (pred_trg_main_0) for the pseudo-label CE
        pred_trg_main_0 = interp_target(pred_trg_main)
        # NOTE(review): implicit-dim softmax is deprecated; defaults to the
        # class channel here — confirm before changing.
        pred_trg_main = F.softmax(pred_trg_main_0)

        def toweight(x):
            # Map a discriminator output map to a [0, 1.5] weight map:
            # standardize (presumably sklearn preprocessing.scale — verify),
            # squash through a sigmoid, scale by 1.5. Operates on
            # x[0][0] on CPU, returns a tensor on `device`.
            x = x.cpu().data[0][0]
            x = preprocessing.scale(x)
            x = 1 / (1 + np.exp(-x))
            x = x * 1.5
            x = torch.tensor(x, dtype=torch.float32, device=device)
            return x

        if cfg.TRAIN.MULTI_LEVEL:
            # pred_trg_aux = F.softmax(interp_target(pred_trg_aux))
            # d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux)))  # -p*log(p)
            d_out_aux = interp_target(d_aux(pred_trg_aux))  # H/8->H/8->H
            loss_adv_trg_aux = 0
            # NOTE(review): `ones`/`zero` are only defined in this branch
            # but are used below whenever i_iter > 5000 — with
            # MULTI_LEVEL=False that path raises NameError. Verify intent.
            ones = torch.ones_like(d_out_aux)
            zero = torch.zeros_like(d_out_aux)
            # if (i_iter > 5000):
            #     pred_trg_aux_conf = 1.0 - torch.max(pred_trg_aux, 1)[0]
            #     weight_map_aux = torch.unsqueeze(pred_trg_aux_conf, dim=0)
            #     loss_adv_trg_aux = weighted_bce_loss(d_out_aux, Variable(torch.FloatTensor(d_out_aux.data.size()).fill_(source_label).to(device)),
            #                                          weight_map_aux, Epsilon, Lambda_local)
            # else:
            #     loss_adv_trg_aux = bce_loss(d_out_aux, source_label)
        else:
            loss_adv_trg_aux = 0
        # pred_trg_main = F.softmax(interp_target(pred_trg_main))  # H/8->H
        d_out_main = interp_target(d_main(pred_trg_main))  # H->H/8->H
        # loss_adv_trg_main = bce_loss(d_out_main, source_label)
        if (i_iter > 5000):
            # After warm-up: pseudo-label CE (confidence > 0.90, rest
            # ignored as 255) + confidence-weighted adversarial loss.
            maxpred, label = torch.max(pred_trg_main.detach(), dim=1)
            mask = (maxpred > 0.90)
            label = torch.where(
                mask, label,
                torch.ones(1).to(device, dtype=torch.long) * 255)
            # NOTE(review): loss_seg_trg_main is logged below but never
            # back-propagated (its backward() call is commented out).
            loss_seg_trg_main = criterion_seg(pred_trg_main_0, label)
            # loss_seg_trg_main_.backward()
            # per-pixel uncertainty: 1 - max class probability
            pred_trg_main_conf = 1.0 - torch.max(pred_trg_main, 1)[0]
            fweight = toweight(d_out_aux)
            # pred_trg_main_conf = 1 - torch.max(pred_trg_main.detach(), 1)[0]
            # fweight = toweight(d_out_aux.detach())
            weight_map_main = pred_trg_main_conf * fweight
            # clamp to [0, 1]: cap at 1, zero-out weights below 0.05
            weight_map_main = torch.where(weight_map_main > 1, ones,
                                          weight_map_main)
            weight_map_main = torch.where(weight_map_main < 0.05, zero,
                                          weight_map_main)
            # weight_map_main = torch.unsqueeze(weight_map_main, dim=0)
            loss_adv_trg_main = weighted_bce_loss(
                d_out_main,
                Variable(
                    torch.FloatTensor(d_out_main.data.size()).fill_(
                        source_label).to(device)),
                weight_map_main, Epsilon, Lambda_local)
        else:
            loss_adv_trg_main = bce_loss(d_out_main, source_label)
            loss_seg_trg_main = 0
        loss = cfg.TRAIN.LAMBDA_ADV_MAIN * loss_adv_trg_main * damping
        loss.backward()
        ### Train discriminator networks
        # enable training mode on discriminator networks
        for param in d_aux.parameters():
            param.requires_grad = True
        for param in d_main.parameters():
            param.requires_grad = True
        # train with source
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = pred_src_aux.detach()
            # d_out_aux = interp(d_aux(F.softmax(pred_src_aux)))  # -plog(p)
            d_out_aux = interp(d_aux(pred_src_aux))  # H/8->H/8->H
            # d_out_aux = d_aux(prob_2_entropy(pred_src_aux))
            loss_d_aux = bce_loss(d_out_aux, source_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        pred_src_main = pred_src_main.detach()
        d_out_main = interp(d_main(F.softmax(pred_src_main)))  # H->H/8->H
        # d_out_main = d_main(prob_2_entropy(pred_src_main))
        loss_d_main = bce_loss(d_out_main, source_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()
        # train with target
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = pred_trg_aux.detach()
            # d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux)))
            d_out_aux = interp_target(d_aux(pred_trg_aux))  # H/8->H/8->H
            loss_d_aux = bce_loss(d_out_aux, target_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
            # if (i_iter > 5000):
            #     weight_map_aux = weight_map_aux.detach()
            #     loss_d_aux = weighted_bce_loss(d_out_aux, Variable(torch.FloatTensor(d_out_aux.data.size()).fill_(target_label).to(device)),
            #                                    weight_map_aux, Epsilon, Lambda_local)
            # else:
            #     loss_d_aux = bce_loss(d_out_aux, target_label)
        else:
            loss_d_aux = 0
        pred_trg_main = pred_trg_main.detach()
        d_out_main = interp_target(d_main(pred_trg_main))
        # loss_d_main = bce_loss(d_out_main, target_label)
        if (i_iter > 5000):
            pred_trg_main_conf = pred_trg_main_conf.detach()
            fweight = toweight(d_out_aux)
            # fweight = toweight(d_out_aux.detach())
            weight_map_main = pred_trg_main_conf * fweight
            weight_map_main = torch.where(weight_map_main > 1, ones,
                                          weight_map_main)
            # weight_map_main = torch.unsqueeze(weight_map_main, dim=0)
            loss_d_main = weighted_bce_loss(
                d_out_main,
                Variable(
                    torch.FloatTensor(d_out_main.data.size()).fill_(
                        target_label).to(device)),
                weight_map_main, Epsilon, Lambda_local)
        else:
            loss_d_main = bce_loss(d_out_main, target_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()
        optimizer.step()
        if cfg.TRAIN.MULTI_LEVEL:
            optimizer_d_aux.step()
        optimizer_d_main.step()
        current_losses = {
            'loss_seg_trg_main': loss_seg_trg_main,
            'loss_seg_src_main': loss_seg_src_main,
            'loss_adv_trg_aux': loss_adv_trg_aux,
            'loss_adv_trg_main': loss_adv_trg_main,
            'loss_d_aux': loss_d_aux,
            'loss_d_main': loss_d_main
        }
        print_losses(current_losses, i_iter)
        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print('taking snapshot ...')
            print('exp =', cfg.TRAIN.SNAPSHOT_DIR)
            snapshot_dir = Path(cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(), snapshot_dir / f'model_{i_iter}.pth')
            torch.save(d_aux.state_dict(), snapshot_dir / f'model_{i_iter}_D_aux.pth')
            torch.save(d_main.state_dict(), snapshot_dir / f'model_{i_iter}_D_main.pth')
            if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
                break
        sys.stdout.flush()
        # Visualize with tensorboard
        if viz_tensorboard:
            log_losses_tensorboard(writer, current_losses, i_iter)
            if i_iter % cfg.TRAIN.TENSORBOARD_VIZRATE == cfg.TRAIN.TENSORBOARD_VIZRATE - 1:
                draw_in_tensorboard(writer, images, i_iter, pred_trg_main,
                                    num_classes, 'T')
                draw_in_tensorboard(writer, images_source, i_iter, pred_src_main,
                                    num_classes, 'S')
def train_dada(model, trainloader, targetloader, cfg):
    """UDA training with DADA (depth-aware adversarial adaptation).

    Like AdvEnt but single-level, with an extra depth head: the source batch
    adds a depth-regression loss, and the discriminator sees the entropy map
    multiplied element-wise by the predicted depth map.

    :param model: network returning ``(aux, seg_logits, depth)`` per forward.
    :param trainloader: source loader yielding ``(images, labels, depth, _, _)``.
    :param targetloader: target loader yielding ``(images, _, _, _)``.
    :param cfg: experiment configuration.
    """
    # Create the model and start the training.
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES
    viz_tensorboard = os.path.exists(cfg.TRAIN.TENSORBOARD_LOGDIR)
    if viz_tensorboard:
        writer = SummaryWriter(log_dir=cfg.TRAIN.TENSORBOARD_LOGDIR)
    # SEGMENTATION NETWORK
    model.train()
    model.to(device)
    cudnn.benchmark = True
    cudnn.enabled = True
    # DISCRIMINATOR NETWORK (output level only in DADA).
    d_main = get_fc_discriminator(num_classes=num_classes)
    d_main.train()
    d_main.to(device)
    # OPTIMIZERS: SGD for the segmenter, Adam for the discriminator.
    optimizer = optim.SGD(
        model.optim_parameters(cfg.TRAIN.LEARNING_RATE),
        lr=cfg.TRAIN.LEARNING_RATE,
        momentum=cfg.TRAIN.MOMENTUM,
        weight_decay=cfg.TRAIN.WEIGHT_DECAY,
    )
    optimizer_d_main = optim.Adam(d_main.parameters(),
                                  lr=cfg.TRAIN.LEARNING_RATE_D,
                                  betas=(0.9, 0.99))
    # Interpolate output segmaps to full resolution; sizes are (W, H).
    interp = nn.Upsample(
        size=(input_size_source[1], input_size_source[0]),
        mode="bilinear",
        align_corners=True,
    )
    interp_target = nn.Upsample(
        size=(input_size_target[1], input_size_target[0]),
        mode="bilinear",
        align_corners=True,
    )
    # Domain labels for adversarial training.
    source_label = 0
    target_label = 1
    trainloader_iter = enumerate(trainloader)
    targetloader_iter = enumerate(targetloader)
    for i_iter in tqdm(range(cfg.TRAIN.EARLY_STOP + 1)):
        # Reset optimizers and adapt LR if needed.
        optimizer.zero_grad()
        optimizer_d_main.zero_grad()
        adjust_learning_rate(optimizer, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg)
        # --- UDA Training: only train segnet; freeze discriminator grads.
        for param in d_main.parameters():
            param.requires_grad = False
        # Supervised segmentation + depth losses on a source batch.
        _, batch = trainloader_iter.__next__()
        images_source, labels, depth, _, _ = batch
        _, pred_src_main, pred_depth_src_main = model(
            images_source.cuda(device))
        pred_src_main = interp(pred_src_main)
        pred_depth_src_main = interp(pred_depth_src_main)
        loss_depth_src_main = loss_calc_depth(pred_depth_src_main, depth, device)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        loss = (cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main
                + cfg.TRAIN.LAMBDA_DEPTH_MAIN * loss_depth_src_main)
        loss.backward()
        # Adversarial loss: depth-weighted target entropy maps labeled as
        # "source". dim=1 made explicit (class channel); the implicit-dim
        # form of F.softmax is deprecated.
        _, batch = targetloader_iter.__next__()
        images, _, _, _ = batch
        _, pred_trg_main, pred_depth_trg_main = model(images.cuda(device))
        pred_trg_main = interp_target(pred_trg_main)
        pred_depth_trg_main = interp_target(pred_depth_trg_main)
        d_out_main = d_main(
            prob_2_entropy(F.softmax(pred_trg_main, dim=1)) * pred_depth_trg_main)
        loss_adv_trg_main = bce_loss(d_out_main, source_label)
        loss = cfg.TRAIN.LAMBDA_ADV_MAIN * loss_adv_trg_main
        loss.backward()
        # --- Train the discriminator (segmenter outputs detached).
        for param in d_main.parameters():
            param.requires_grad = True
        # Train with source (label 0). Unlike the AdvEnt loops above, this
        # variant does not halve the discriminator losses.
        pred_src_main = pred_src_main.detach()
        pred_depth_src_main = pred_depth_src_main.detach()
        d_out_main = d_main(
            prob_2_entropy(F.softmax(pred_src_main, dim=1)) * pred_depth_src_main)
        loss_d_main = bce_loss(d_out_main, source_label)
        loss_d_main.backward()
        # Train with target (label 1).
        pred_trg_main = pred_trg_main.detach()
        pred_depth_trg_main = pred_depth_trg_main.detach()
        d_out_main = d_main(
            prob_2_entropy(F.softmax(pred_trg_main, dim=1)) * pred_depth_trg_main)
        loss_d_main = bce_loss(d_out_main, target_label)
        loss_d_main.backward()
        optimizer.step()
        optimizer_d_main.step()
        current_losses = {
            "loss_seg_src_main": loss_seg_src_main,
            "loss_depth_src_main": loss_depth_src_main,
            "loss_adv_trg_main": loss_adv_trg_main,
            "loss_d_main": loss_d_main,
        }
        print_losses(current_losses, i_iter)
        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print("taking snapshot ...")
            print("exp =", cfg.TRAIN.SNAPSHOT_DIR)
            snapshot_dir = Path(cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(), snapshot_dir / f"model_{i_iter}.pth")
            torch.save(d_main.state_dict(),
                       snapshot_dir / f"model_{i_iter}_D_main.pth")
            if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
                break
        sys.stdout.flush()
        # Visualize with tensorboard.
        if viz_tensorboard:
            log_losses_tensorboard(writer, current_losses, i_iter)
            if i_iter % cfg.TRAIN.TENSORBOARD_VIZRATE == cfg.TRAIN.TENSORBOARD_VIZRATE - 1:
                draw_in_tensorboard(writer, images, i_iter, pred_trg_main,
                                    num_classes, "T")
                draw_in_tensorboard(writer, images_source, i_iter, pred_src_main,
                                    num_classes, "S")
def train_advent(model, trainloader, targetloader, cfg, args):
    ''' UDA training with advent.

    Adversarial training on weighted self-information (entropy) maps: the
    segmentation network is trained with CE on labelled source batches while
    two fc-discriminators (aux- and main-output level) learn to separate
    source from target entropy maps, and the segmenter is simultaneously
    trained to fool them on target batches.

    :param model: segmentation network exposing ``optim_parameters`` and
        returning ``(pred_aux, pred_main)`` logits.
    :param trainloader: source-domain loader yielding
        ``(images, labels, _, _)``.
    :param targetloader: target-domain loader yielding ``(images, _, _, _)``.
    :param cfg: config node (``cfg.TRAIN.*``, ``cfg.GPU_ID``, ...).
    :param args: CLI args; ``args.FDA_mode`` must be ``'on'`` or ``'off'``.
    :raises KeyError: if ``args.FDA_mode`` is neither ``'on'`` nor ``'off'``.
    '''
    # Create the model and start the training.
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    # per-channel image mean reshaped to (1, 3, 1, 1) for NCHW broadcasting
    SRC_IMG_MEAN = np.asarray(cfg.TRAIN.IMG_MEAN, dtype=np.float32)
    SRC_IMG_MEAN = torch.reshape(torch.from_numpy(SRC_IMG_MEAN), (1, 3, 1, 1))
    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES
    viz_tensorboard = os.path.exists(cfg.TRAIN.TENSORBOARD_LOGDIR)
    if viz_tensorboard:
        writer = SummaryWriter(log_dir=cfg.TRAIN.TENSORBOARD_LOGDIR)

    # -------------------------------------------------------- #
    # initialize wandb for storing logs on its cloud
    wandb.init(project='FDA_integration_to_INTRA_DA')
    wandb.config.update(args)
    for key, val in cfg.items():
        wandb.config.update({key: val})
    wandb.watch(model)
    # -------------------------------------------------------- #

    # SEGMENTATION NETWORK
    model.train()
    model.to(device)
    cudnn.benchmark = True
    cudnn.enabled = True

    # DISCRIMINATOR NETWORKS
    # feature-level discriminator (aux output)
    d_aux = get_fc_discriminator(num_classes=num_classes)
    d_aux.train()
    d_aux.to(device)
    # seg-map (main output) level discriminator
    d_main = get_fc_discriminator(num_classes=num_classes)
    d_main.train()
    d_main.to(device)

    # OPTIMIZERS
    # segnet's optimizer
    optimizer = optim.SGD(model.optim_parameters(cfg.TRAIN.LEARNING_RATE),
                          lr=cfg.TRAIN.LEARNING_RATE,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    # discriminators' optimizers
    optimizer_d_aux = optim.Adam(d_aux.parameters(),
                                 lr=cfg.TRAIN.LEARNING_RATE_D,
                                 betas=(0.9, 0.99))
    optimizer_d_main = optim.Adam(d_main.parameters(),
                                  lr=cfg.TRAIN.LEARNING_RATE_D,
                                  betas=(0.9, 0.99))

    # interpolate output segmaps; INPUT_SIZE_* is (W, H), Upsample wants (H, W)
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1
    trainloader_iter = enumerate(trainloader)
    targetloader_iter = enumerate(targetloader)
    for i_iter in tqdm(range(cfg.TRAIN.EARLY_STOP + 1)):

        # reset optimizers
        optimizer.zero_grad()
        optimizer_d_aux.zero_grad()
        optimizer_d_main.zero_grad()
        # adapt LR if needed
        adjust_learning_rate(optimizer, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_aux, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg)

        # UDA Training
        # only train segnet. Don't accumulate grads in discriminators
        for param in d_aux.parameters():
            param.requires_grad = False
        for param in d_main.parameters():
            param.requires_grad = False

        _, batch = trainloader_iter.__next__()
        images_source, labels, _, _ = batch
        _, batch = targetloader_iter.__next__()
        images, _, _, _ = batch

        # ----------------------------------------------------------------#
        B, _, H, W = images_source.shape
        mean_images_source = SRC_IMG_MEAN.repeat(B, 1, H, W)
        mean_images = SRC_IMG_MEAN.repeat(B, 1, H, W)
        if args.FDA_mode == 'on':
            # normalize the source and target image (in-place subtraction)
            images_source -= mean_images_source
            images -= mean_images
        elif args.FDA_mode == 'off':
            # keep source and target images as they are: normalization was
            # already done in the dataset classes (GTA5, cityscapes)
            pass
        else:
            raise KeyError(f"unknown FDA_mode: {args.FDA_mode!r}")
        # ----------------------------------------------------------------#

        # train on source
        pred_src_aux, pred_src_main = model(images_source.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = interp(pred_src_aux)
            loss_seg_src_aux = loss_calc(pred_src_aux, labels, device)
        else:
            loss_seg_src_aux = 0
        pred_src_main = interp(pred_src_main)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        loss = (cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main
                + cfg.TRAIN.LAMBDA_SEG_AUX * loss_seg_src_aux)
        loss.backward()

        # adversarial training to fool the discriminator:
        # target predictions should be classified as source_label
        pred_trg_aux, pred_trg_main = model(images.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = interp_target(pred_trg_aux)
            # dim=1 (class channel) matches the pre-deprecation implicit dim
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux, dim=1)))
            loss_adv_trg_aux = bce_loss(d_out_aux, source_label)
        else:
            loss_adv_trg_aux = 0
        pred_trg_main = interp_target(pred_trg_main)
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main, dim=1)))
        loss_adv_trg_main = bce_loss(d_out_main, source_label)
        loss = (cfg.TRAIN.LAMBDA_ADV_MAIN * loss_adv_trg_main
                + cfg.TRAIN.LAMBDA_ADV_AUX * loss_adv_trg_aux)
        loss.backward()

        # Train discriminator networks
        # enable training mode on discriminator networks
        for param in d_aux.parameters():
            param.requires_grad = True
        for param in d_main.parameters():
            param.requires_grad = True
        # train with source; detach so no grad flows into the segnet,
        # /2 averages the source and target halves of the discriminator loss
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = pred_src_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_src_aux, dim=1)))
            loss_d_aux = bce_loss(d_out_aux, source_label) / 2
            loss_d_aux.backward()
        pred_src_main = pred_src_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_src_main, dim=1)))
        loss_d_main = bce_loss(d_out_main, source_label) / 2
        loss_d_main.backward()
        # train with target
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = pred_trg_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux, dim=1)))
            loss_d_aux = bce_loss(d_out_aux, target_label) / 2
            loss_d_aux.backward()
        else:
            loss_d_aux = 0
        pred_trg_main = pred_trg_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main, dim=1)))
        loss_d_main = bce_loss(d_out_main, target_label) / 2
        loss_d_main.backward()

        optimizer.step()
        if cfg.TRAIN.MULTI_LEVEL:
            optimizer_d_aux.step()
        optimizer_d_main.step()

        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print('taking snapshot ...')
            print('exp =', cfg.TRAIN.SNAPSHOT_DIR)
            snapshot_dir = Path(cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(),
                       snapshot_dir / f'model_{i_iter}.pth')
            torch.save(d_aux.state_dict(),
                       snapshot_dir / f'model_{i_iter}_D_aux.pth')
            torch.save(d_main.state_dict(),
                       snapshot_dir / f'model_{i_iter}_D_main.pth')
        if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
            break
        sys.stdout.flush()

        # Visualize with tensorboard
        if viz_tensorboard:
            # ----------------------------------------------------------------#
            if i_iter % cfg.TRAIN.TENSORBOARD_VIZRATE == cfg.TRAIN.TENSORBOARD_VIZRATE - 1:
                current_losses = {'loss_seg_src_aux': loss_seg_src_aux,
                                  'loss_seg_src_main': loss_seg_src_main,
                                  'loss_adv_trg_aux': loss_adv_trg_aux,
                                  'loss_adv_trg_main': loss_adv_trg_main,
                                  'loss_d_aux': loss_d_aux,
                                  'loss_d_main': loss_d_main}
                print_losses(current_losses, i_iter)
                log_losses_tensorboard(writer, current_losses, i_iter)
                # add the mean back so the drawn images are de-normalized
                draw_in_tensorboard(writer, images + mean_images, i_iter,
                                    pred_trg_main, num_classes, 'T')
                draw_in_tensorboard(writer, images_source + mean_images_source,
                                    i_iter, pred_src_main, num_classes, 'S')
                wandb.log({'loss': current_losses}, step=(i_iter + 1))

                # BUG FIX: the original condition was
                #   i_iter % (VIZRATE == VIZRATE) * 25 - 1
                # which is the constant -1 (always truthy). Per the original
                # inline comment ("for every 2500 iteration") it was meant to
                # fire once every VIZRATE * 25 iterations.
                viz_img_rate = cfg.TRAIN.TENSORBOARD_VIZRATE * 25
                if i_iter % viz_img_rate == viz_img_rate - 1:
                    # torch.flip(..., [1]) reverses the channel axis
                    # (presumably BGR -> RGB for display — TODO confirm)
                    # NOTE(review): the label reshape hard-codes (512, 1024)
                    # and the 'pesudo label' key keeps its original spelling
                    # so existing wandb dashboards stay consistent.
                    wandb.log(
                        {'source': wandb.Image(
                            torch.flip(images_source + mean_images_source, [1])
                            .cpu().data[0].numpy().transpose((1, 2, 0))),
                         'target': wandb.Image(
                            torch.flip(images + mean_images, [1])
                            .cpu().data[0].numpy().transpose((1, 2, 0))),
                         'pesudo label': wandb.Image(np.asarray(
                            colorize_mask(np.asarray(
                                labels.cpu().data.numpy().transpose(1, 2, 0)
                                .reshape((512, 1024)),
                                dtype=np.uint8)).convert('RGB')))},
                        step=(i_iter + 1))
def train_self_domain_swarp(model, trainloader, targetloader, cfg):
    ''' UDA self-training with domain swarping on target features.

    A frozen deep copy of the model (``model_runner``) produces target
    pseudo-labels via ``label_generator`` (which also adapts the per-class
    thresholds ``cls_thresh``). The live model is trained with:
    source CE loss + regularized CE on its raw target predictions +
    regularized CE on predictions decoded from domain-swarped target
    features; pixels whose swarped label equals 255 fall back to the
    original features.

    :param model: segmentation network returning ``(pred, feat)`` and
        exposing a ``classifier_`` head over features.
    :param trainloader: source loader yielding ``(images, labels, _, _)``.
    :param targetloader: target loader yielding
        ``(images, images_rev, _, _, name, name_next)``.
    :param cfg: config node (``cfg.TRAIN.*``, ``cfg.GPU_ID``, ...).
    '''
    # Create the model and start the training.
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES
    viz_tensorboard = os.path.exists(cfg.TRAIN.TENSORBOARD_LOGDIR)
    if viz_tensorboard:
        writer = SummaryWriter(log_dir=cfg.TRAIN.TENSORBOARD_LOGDIR)

    # SEGMENTATION NETWORK
    model.train()
    model.to(device)
    # Model clone: pseudo-label generator; kept in eval mode and never
    # updated by the optimizer (only queried under torch.no_grad below)
    model_runner = copy.deepcopy(model)
    model_runner.eval()
    model_runner.to(device)
    # running per-class feature dictionary consumed/updated by DomainSwarping
    tgt_dict_tot = {}
    cudnn.benchmark = True
    cudnn.enabled = True

    # OPTIMIZER (segnet only; no discriminator in this variant)
    optimizer = optim.SGD(model.parameters(),
                          lr=cfg.TRAIN.LEARNING_RATE,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)

    # interpolate output segmaps; INPUT_SIZE_* is (W, H), Upsample wants (H, W)
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear', align_corners=True)

    # per-class confidence thresholds, adapted each step by label_generator
    cls_thresh = torch.ones(num_classes).type(torch.float32)

    trainloader_iter = enumerate(trainloader)
    targetloader_iter = enumerate(targetloader)
    # one pass over the target loader, capped by EARLY_STOP below
    tot_iter = len(targetloader)
    for i_iter in tqdm(range(tot_iter)):

        # reset optimizer
        optimizer.zero_grad()
        # adapt LR if needed
        adjust_learning_rate(optimizer, i_iter, cfg)

        # train on source
        _, batch = trainloader_iter.__next__()
        images_source, labels, _, _ = batch
        pred_src_main, _ = model(images_source.cuda(device))
        pred_src_main = interp(pred_src_main)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        loss = cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main
        loss.backward()

        # train on target with generated pseudo-labels
        _, batch = targetloader_iter.__next__()
        images, images_rev, _, _, name, name_next = batch
        pred_trg_main, feat_trg_main = model(images.cuda(device))
        pred_trg_main = interp_target(pred_trg_main)
        with torch.no_grad():
            pred_trg_main_run, feat_trg_main_run = model_runner(images.cuda(device))
            pred_trg_main_run = interp_target(pred_trg_main_run)

        ##### Label generator for target #####
        label_trg, cls_thresh = label_generator(pred_trg_main_run, cls_thresh,
                                                cfg, i_iter, tot_iter)

        ##### CE loss for trg (MRKLD + Ign Region) #####
        loss_seg_trg_main = reg_loss_calc_ign(pred_trg_main, label_trg, device)
        loss_tgt_seg = cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_trg_main

        ##### Domain swarping #####
        feat_tgt_swarped, tgt_dict_tot, tgt_label = DomainSwarping(
            feat_trg_main, label_trg, tgt_dict_tot, device)
        # where the swarped label is the ignore index (255), keep the
        # original features instead of the swarped ones
        ignore_mask = tgt_label == 255
        feat_tgt_swarped = (~ignore_mask * feat_tgt_swarped
                            + ignore_mask * feat_trg_main)
        pred_tgt_swarped = model.classifier_(feat_tgt_swarped)
        pred_tgt_swarped = interp_target(pred_tgt_swarped)
        loss_seg_trg_swarped = reg_loss_calc_ign(pred_tgt_swarped, label_trg, device)
        loss_tgt_seg_swarped = cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_trg_swarped

        loss_tgt = loss_tgt_seg + loss_tgt_seg_swarped
        loss_tgt.backward()
        optimizer.step()

        current_losses = {'loss_seg_trg_main': loss_seg_trg_main,
                          'loss_seg_src_main': loss_seg_src_main,
                          'loss_seg_trg_swarped': loss_seg_trg_swarped}
        print_losses(current_losses, i_iter)

        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print('taking snapshot ...')
            print('exp =', cfg.TRAIN.SNAPSHOT_DIR)
            snapshot_dir = Path(cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(),
                       snapshot_dir / f'model_{i_iter}.pth')
            torch.save(model_runner.state_dict(),
                       snapshot_dir / f'model_{i_iter}_run.pth')
        if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
            break
        sys.stdout.flush()

        # Visualize with tensorboard
        if viz_tensorboard:
            log_losses_tensorboard(writer, current_losses, i_iter)
            if i_iter % cfg.TRAIN.TENSORBOARD_VIZRATE == 0:
                # NOTE(review): this draw_in_tensorboard call passes 8 args
                # (label_trg / pred_tgt_swarped included), unlike the 6-arg
                # calls elsewhere in this file — verify the helper signature.
                draw_in_tensorboard(writer, images, label_trg, i_iter,
                                    pred_trg_main, pred_tgt_swarped,
                                    num_classes, 'T')