import logging
import os
from collections import deque
from datetime import datetime

import numpy as np
import torch
from PIL import Image
from tensorboardX import SummaryWriter  # or torch.utils.tensorboard
from torch.autograd import Variable

# Project-local helpers (get_model, Discriminator, AddaDataLoader,
# make_variable, supervised_loss, check_label, seg_accuracy, config_logging,
# mmd_loss, IoU, recall, norm, mxAxis) are assumed importable from the
# surrounding package.


def main(output, dataset, datadir, lr, momentum, snapshot, downscale,
         cls_weights, gpu, weights_init, num_cls, lsgan, max_iter, lambda_d,
         lambda_g, train_discrim_only, weights_discrim, crop_size,
         weights_shared, discrim_feat, half_crop, batch, model, data_flag,
         resize, with_mmd_loss, small):
    # so data is sampled in a consistent way
    np.random.seed(1336)
    torch.manual_seed(1336)

    logdir = 'runs/{:s}/{:s}_to_{:s}/lr{:.1g}_ld{:.2g}_lg{:.2g}'.format(
        model, dataset[0], dataset[1], lr, lambda_d, lambda_g)
    if weights_shared:
        logdir += '_weights_shared'
    else:
        logdir += '_weights_unshared'
    if discrim_feat:
        logdir += '_discrim_feat'
    else:
        logdir += '_discrim_score'
    logdir += '/' + datetime.now().strftime('%Y_%b_%d-%H:%M')
    writer = SummaryWriter(log_dir=logdir)

    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    config_logging()
    print('Train Discrim Only', train_discrim_only)

    # fcn8s is built without the finetune flag; other models are fine-tuned
    if model == 'fcn8s':
        net = get_model(model, num_cls=num_cls, pretrained=True,
                        weights_init=weights_init,
                        output_last_ft=discrim_feat)
    else:
        net = get_model(model, num_cls=num_cls, finetune=True,
                        pretrained=True, weights_init=weights_init,
                        output_last_ft=discrim_feat)
    net.cuda()

    # parse the comma-separated GPU list, e.g. '0,1'
    str_ids = gpu.split(',')
    gpu_ids = []
    for str_id in str_ids:
        gpu_id = int(str_id)
        if gpu_id >= 0:
            gpu_ids.append(gpu_id)

    # set gpu ids
    if len(gpu_ids) > 0:
        torch.cuda.set_device(gpu_ids[0])
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)

    if weights_shared:
        net_src = net  # shared weights
    else:
        net_src = get_model(model, num_cls=num_cls, finetune=True,
                            pretrained=True, weights_init=weights_init,
                            output_last_ft=discrim_feat)
        net_src.eval()

    # initialize discriminator
    odim = 1 if lsgan else 2
    idim = num_cls if not discrim_feat else 4096
    print('Discrim_feat', discrim_feat, idim)
    print('Discriminator init weights: ', weights_discrim)
    discriminator = Discriminator(input_dim=idim, output_dim=odim,
                                  pretrained=weights_discrim is not None,
                                  weights_init=weights_discrim).cuda()
    discriminator.to(gpu_ids[0])
    discriminator = torch.nn.DataParallel(discriminator, gpu_ids)

    loader = AddaDataLoader(net.module.transform, dataset, datadir, downscale,
                            resize=resize, crop_size=crop_size,
                            half_crop=half_crop, batch_size=batch,
                            shuffle=True, num_workers=16,
                            src_data_flag=data_flag, small=small)
    print('dataset', dataset)

    # class-weighted loss?
    if cls_weights is not None:
        weights = np.loadtxt(cls_weights)
    else:
        weights = None

    # set up optimizers
    opt_dis = torch.optim.SGD(discriminator.module.parameters(), lr=lr,
                              momentum=momentum, weight_decay=0.0005)
    opt_rep = torch.optim.SGD(net.module.parameters(), lr=lr,
                              momentum=momentum, weight_decay=0.0005)

    iteration = 0
    num_update_g = 0
    last_update_g = -1
    losses_super_s = deque(maxlen=100)
    losses_super_t = deque(maxlen=100)
    losses_dis = deque(maxlen=100)
    losses_rep = deque(maxlen=100)
    accuracies_dom = deque(maxlen=100)
    intersections = np.zeros([100, num_cls])
    iu_deque = deque(maxlen=100)
    unions = np.zeros([100, num_cls])
    accuracy = deque(maxlen=100)
    print('Max Iter:', max_iter)

    net.train()
    discriminator.train()

    # sanity-check one sample from each domain in debug mode
    loader.loader_src.dataset.__getitem__(0, debug=True)
    loader.loader_tgt.dataset.__getitem__(0, debug=True)

    while iteration < max_iter:
        for im_s, im_t, label_s, label_t in loader:

            if iteration == 0:
                print("IM S: {}".format(im_s.size()))
                print("Label S: {}".format(label_s.size()))
                print("IM T: {}".format(im_t.size()))
                print("Label T: {}".format(label_t.size()))

            if iteration > max_iter:
                break

            info_str = 'Iteration {}: '.format(iteration)

            if not check_label(label_s, num_cls):
                continue

            ###########################
            # 1. Setup Data Variables #
            ###########################
            im_s = make_variable(im_s, requires_grad=False)
            label_s = make_variable(label_s, requires_grad=False)
            im_t = make_variable(im_t, requires_grad=False)
            label_t = make_variable(label_t, requires_grad=False)

            #############################
            # 2. Optimize Discriminator #
            #############################

            # zero gradients for optimizer
            opt_dis.zero_grad()
            opt_rep.zero_grad()

            # extract features; the discriminator sees either the last
            # feature map (discrim_feat) or the class score map
            if discrim_feat:
                score_s, feat_s = net_src(im_s)
                score_s = Variable(score_s.data, requires_grad=False)
                f_s = Variable(feat_s.data, requires_grad=False)
            else:
                score_s = Variable(net_src(im_s).data, requires_grad=False)
                f_s = score_s
            dis_score_s = discriminator(f_s)

            if discrim_feat:
                score_t, feat_t = net(im_t)
                score_t = Variable(score_t.data, requires_grad=False)
                f_t = Variable(feat_t.data, requires_grad=False)
            else:
                score_t = Variable(net(im_t).data, requires_grad=False)
                f_t = score_t
            dis_score_t = discriminator(f_t)

            dis_pred_concat = torch.cat((dis_score_s, dis_score_t))

            # prepare real and fake labels: source = 1, target = 0
            batch_t, _, h, w = dis_score_t.size()
            batch_s, _, _, _ = dis_score_s.size()
            dis_label_concat = make_variable(torch.cat([
                torch.ones(batch_s, h, w).long(),
                torch.zeros(batch_t, h, w).long()
            ]), requires_grad=False)

            # compute loss for discriminator
            loss_dis = supervised_loss(dis_pred_concat, dis_label_concat)
            (lambda_d * loss_dis).backward()
            losses_dis.append(loss_dis.item())

            # optimize discriminator
            opt_dis.step()

            # compute discriminator accuracy
            pred_dis = torch.squeeze(dis_pred_concat.max(1)[1])
            dom_acc = (pred_dis == dis_label_concat).float().mean().item()
            accuracies_dom.append(dom_acc * 100.)

            # add discriminator info to log
            info_str += ' domacc:{:0.1f} D:{:.3f}'.format(
                np.mean(accuracies_dom), np.mean(losses_dis))
            writer.add_scalar('loss/discriminator', np.mean(losses_dis),
                              iteration)
            writer.add_scalar('acc/discriminator', np.mean(accuracies_dom),
                              iteration)

            ###########################
            # Optimize Target Network #
            ###########################

            # only update G once the discriminator is doing better than chance
            dom_acc_thresh = 60

            if train_discrim_only and \
                    np.mean(accuracies_dom) > dom_acc_thresh:
                os.makedirs(output, exist_ok=True)
                torch.save(discriminator.module.state_dict(),
                           '{}/discriminator_abv60.pth'.format(output))
                break

            if not train_discrim_only and \
                    np.mean(accuracies_dom) > dom_acc_thresh:
                last_update_g = iteration
                num_update_g += 1
                print('Updating G with adversarial loss ({:d} times)'.format(
                    num_update_g))

                # zero out optimizer gradients
                opt_dis.zero_grad()
                opt_rep.zero_grad()

                # extract features
                if discrim_feat:
                    score_t, feat_t = net(im_t)
                    score_t = Variable(score_t.data, requires_grad=False)
                    f_t = feat_t
                else:
                    score_t = net(im_t)
                    f_t = score_t
                dis_score_t = discriminator(f_t)

                # create fake label: target pixels pretend to be source (1)
                batch, _, h, w = dis_score_t.size()
                target_dom_fake_t = make_variable(
                    torch.ones(batch, h, w).long(), requires_grad=False)

                # compute loss for target net
                loss_gan_t = supervised_loss(dis_score_t, target_dom_fake_t)
                (lambda_g * loss_gan_t).backward()
                losses_rep.append(loss_gan_t.item())
                writer.add_scalar('loss/generator', np.mean(losses_rep),
                                  iteration)

                # optimize target net
                opt_rep.step()

                # log net update info
                info_str += ' G:{:.3f}'.format(np.mean(losses_rep))

            if (not train_discrim_only) and weights_shared and \
                    np.mean(accuracies_dom) > dom_acc_thresh:
                print('Updating G using source supervised loss.')

                # zero out optimizer gradients
                opt_dis.zero_grad()
                opt_rep.zero_grad()

                # extract features
                if discrim_feat:
                    score_s, feat_s = net(im_s)
                else:
                    score_s = net(im_s)

                loss_supervised_s = supervised_loss(score_s, label_s,
                                                    weights=weights)

                if with_mmd_loss:
                    print('Updating G using discrepancy loss')
                    lambda_discrepancy = 0.1
                    # note: feat_s/feat_t only exist when discrim_feat is
                    # set, so the MMD term assumes that configuration
                    loss_mmd = mmd_loss(feat_s, feat_t) * 0.5 + \
                        mmd_loss(score_s, score_t) * 0.5
                    loss_supervised_s += lambda_discrepancy * loss_mmd

                loss_supervised_s.backward()
                losses_super_s.append(loss_supervised_s.item())
                info_str += ' clsS:{:.2f}'.format(np.mean(losses_super_s))
                writer.add_scalar('loss/supervised/source',
                                  np.mean(losses_super_s), iteration)

                # optimize target net
                opt_rep.step()

                # supervised loss on the target -- monitoring only,
                # no backward()
                loss_supervised_t = supervised_loss(score_t, label_t,
                                                    weights=weights)
                losses_super_t.append(loss_supervised_t.item())
                info_str += ' clsT:{:.2f}'.format(np.mean(losses_super_t))
                writer.add_scalar('loss/supervised/target',
                                  np.mean(losses_super_t), iteration)

            ###########################
            # Log and compute metrics #
            ###########################
            if iteration % 10 == 0 and iteration > 0:
                # compute metrics over a rolling window of 100 measurements
                intersection, union, acc = seg_accuracy(
                    score_t, label_t.data, num_cls)
                intersections = np.vstack(
                    [intersections[1:, :], intersection[np.newaxis, :]])
                unions = np.vstack([unions[1:, :], union[np.newaxis, :]])
                accuracy.append(acc.item() * 100)
                acc = np.mean(accuracy)
                mIoU = np.mean(
                    np.maximum(intersections, 1) /
                    np.maximum(unions, 1)) * 100
                iu = (intersection / union) * 100  # per-class IoU in percent
                iu_deque.append(np.nanmean(iu))
                info_str += ' acc:{:0.2f} mIoU:{:0.2f}'.format(
                    acc, np.mean(iu_deque))
                writer.add_scalar('metrics/acc', np.mean(accuracy), iteration)
                writer.add_scalar('metrics/mIoU', mIoU, iteration)
                logging.info(info_str)

            iteration += 1

            ################
            # Save outputs #
            ################

            # every 500 iters save the current model
            if iteration % 500 == 0:
                os.makedirs(output, exist_ok=True)
                if not train_discrim_only:
                    torch.save(net.module.state_dict(),
                               '{}/net-itercurr.pth'.format(output))
                torch.save(discriminator.module.state_dict(),
                           '{}/discriminator-itercurr.pth'.format(output))

            # save labeled snapshots
            if iteration % snapshot == 0:
                os.makedirs(output, exist_ok=True)
                if not train_discrim_only:
                    torch.save(net.module.state_dict(),
                               '{}/net-iter{}.pth'.format(output, iteration))
                torch.save(
                    discriminator.module.state_dict(),
                    '{}/discriminator-iter{}.pth'.format(output, iteration))

            # give up if the discriminator has not crossed the accuracy
            # threshold for three full epochs
            if iteration - last_update_g >= 3 * len(loader):
                print('No suitable discriminator found -- returning.')
                torch.save(net.module.state_dict(),
                           '{}/net-iter{}.pth'.format(output, iteration))
                iteration = max_iter  # make sure outside loop breaks
                break

    writer.close()
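

# The trainer above calls a project-local mmd_loss() helper that is not
# defined in this file. The sketch below illustrates one common choice -- a
# multi-bandwidth RBF-kernel maximum mean discrepancy over flattened feature
# maps -- and is an assumption about, not a copy of, the project's actual
# implementation. The name and the sigmas default are hypothetical.
def mmd_loss_sketch(x, y, sigmas=(1.0, 5.0, 10.0)):
    """Hypothetical biased MMD^2 estimate between two batches of features."""
    # flatten each sample to a vector: (N, C, H, W) -> (N, C*H*W)
    x = x.view(x.size(0), -1)
    y = y.view(y.size(0), -1)
    z = torch.cat([x, y], dim=0)
    # pairwise squared Euclidean distances via ||a-b||^2 = a.a + b.b - 2 a.b
    sq_norms = (z * z).sum(dim=1, keepdim=True)
    dists = sq_norms + sq_norms.t() - 2.0 * torch.mm(z, z.t())
    # sum RBF kernels at several bandwidths for a multi-scale comparison
    k = sum(torch.exp(-dists / (2.0 * s ** 2)) for s in sigmas)
    n = x.size(0)
    k_xx, k_yy, k_xy = k[:n, :n], k[n:, n:], k[:n, n:]
    # MMD^2 = E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)]
    return k_xx.mean() + k_yy.mean() - 2.0 * k_xy.mean()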


# Variant trainer: loads weights_init as a saved state dict, optionally
# backpropagates a supervised target loss (targetsup), and logs IoU/recall
# along with qualitative image dumps.
def main(output, dataset, datadir, lr, momentum, snapshot, downscale,
         cls_weights, weights_init, num_cls, lsgan, max_iter, lambda_d,
         lambda_g, train_discrim_only, weights_discrim, crop_size,
         weights_shared, discrim_feat, half_crop, batch, model, targetsup):
    # whether to also backprop the target supervised loss
    targetSup = targetsup
    # so data is sampled in a consistent way
    np.random.seed(1337)
    torch.manual_seed(1337)

    logdir = 'runs/{:s}/{:s}_to_{:s}/lr{:.1g}_ld{:.2g}_lg{:.2g}'.format(
        model, dataset[0], dataset[1], lr, lambda_d, lambda_g)
    if weights_shared:
        logdir += '_weightshared'
    else:
        logdir += '_weightsunshared'
    if discrim_feat:
        logdir += '_discrimfeat'
    else:
        logdir += '_discrimscore'
    logdir += '/' + datetime.now().strftime('%Y_%b_%d-%H:%M')
    writer = SummaryWriter(logdir)

    config_logging()
    print('Train Discrim Only', train_discrim_only)

    net = get_model(model, num_cls=num_cls, output_last_ft=discrim_feat)
    net.load_state_dict(torch.load(weights_init))
    if weights_shared:
        net_src = net  # shared weights
    else:
        net_src = get_model(model, num_cls=num_cls,
                            output_last_ft=discrim_feat)
        net_src.load_state_dict(torch.load(weights_init))
        net_src.eval()
    print("GOT MODEL")

    # nn.Module.cuda() moves parameters in place, so the shared net_src alias
    # stays valid; inputs from make_variable are assumed to live on the GPU
    # as well when one is available
    if torch.cuda.is_available():
        net.cuda()
        net_src.cuda()

    # initialize discriminator
    odim = 1 if lsgan else 2
    idim = num_cls if not discrim_feat else 4096
    print('discrim_feat', discrim_feat, idim)
    print('discriminator init weights: ', weights_discrim)
    discriminator = Discriminator(input_dim=idim, output_dim=odim,
                                  pretrained=weights_discrim is not None,
                                  weights_init=weights_discrim)
    if torch.cuda.is_available():
        discriminator = discriminator.cuda()

    loader = AddaDataLoader(None, dataset, datadir, downscale,
                            crop_size=crop_size, half_crop=half_crop,
                            batch_size=batch, shuffle=True, num_workers=2)
    print('dataset', dataset)

    # class-weighted loss?
    if cls_weights is not None:
        weights = np.loadtxt(cls_weights)
    else:
        weights = None

    # set up optimizers
    opt_dis = torch.optim.SGD(discriminator.parameters(), lr=lr,
                              momentum=momentum, weight_decay=0.0005)
    opt_rep = torch.optim.SGD(net.parameters(), lr=lr,
                              momentum=momentum, weight_decay=0.0005)

    iteration = 0
    num_update_g = 0
    last_update_g = -1
    losses_super_s = deque(maxlen=100)
    losses_super_t = deque(maxlen=100)
    losses_dis = deque(maxlen=100)
    losses_rep = deque(maxlen=100)
    accuracies_dom = deque(maxlen=100)
    intersections = np.zeros([100, num_cls])
    unions = np.zeros([100, num_cls])
    accuracy = deque(maxlen=100)
    print('max iter:', max_iter)

    net.train()
    discriminator.train()

    IoU_s = deque(maxlen=100)
    IoU_t = deque(maxlen=100)
    Recall_s = deque(maxlen=100)
    Recall_t = deque(maxlen=100)

    while iteration < max_iter:
        for im_s, im_t, label_s, label_t in loader:

            if iteration > max_iter:
                break

            info_str = 'Iteration {}: '.format(iteration)

            if not check_label(label_s, num_cls):
                continue

            ###########################
            # 1. Setup Data Variables #
            ###########################
            im_s = make_variable(im_s, requires_grad=False)
            label_s = make_variable(label_s, requires_grad=False)
            im_t = make_variable(im_t, requires_grad=False)
            label_t = make_variable(label_t, requires_grad=False)

            #############################
            # 2. Optimize Discriminator #
            #############################

            # zero gradients for optimizer
            opt_dis.zero_grad()
            opt_rep.zero_grad()

            # extract features; the discriminator sees either the last
            # feature map (discrim_feat) or the class score map
            if discrim_feat:
                score_s, feat_s = net_src(im_s)
                score_s = Variable(score_s.data, requires_grad=False)
                f_s = Variable(feat_s.data, requires_grad=False)
            else:
                score_s = Variable(net_src(im_s).data, requires_grad=False)
                f_s = score_s
            dis_score_s = discriminator(f_s)

            if discrim_feat:
                score_t, feat_t = net(im_t)
                score_t = Variable(score_t.data, requires_grad=False)
                f_t = Variable(feat_t.data, requires_grad=False)
            else:
                score_t = Variable(net(im_t).data, requires_grad=False)
                f_t = score_t
            dis_score_t = discriminator(f_t)

            dis_pred_concat = torch.cat((dis_score_s, dis_score_t))

            # prepare real and fake labels: source = 1, target = 0
            batch_t, _, h, w = dis_score_t.size()
            batch_s, _, _, _ = dis_score_s.size()
            dis_label_concat = make_variable(torch.cat([
                torch.ones(batch_s, h, w).long(),
                torch.zeros(batch_t, h, w).long()
            ]), requires_grad=False)

            # compute loss for discriminator
            loss_dis = supervised_loss(dis_pred_concat, dis_label_concat)
            (lambda_d * loss_dis).backward()
            losses_dis.append(loss_dis.item())

            # optimize discriminator
            opt_dis.step()

            # compute discriminator accuracy
            pred_dis = torch.squeeze(dis_pred_concat.max(1)[1])
            dom_acc = (pred_dis == dis_label_concat).float().mean().item()
            accuracies_dom.append(dom_acc * 100.)

            # add discriminator info to log
            info_str += ' domacc:{:0.1f} D:{:.3f}'.format(
                np.mean(accuracies_dom), np.mean(losses_dis))
            writer.add_scalar('loss/discriminator', np.mean(losses_dis),
                              iteration)
            writer.add_scalar('acc/discriminator', np.mean(accuracies_dom),
                              iteration)

            ###########################
            # Optimize Target Network #
            ###########################

            # only update G once the discriminator is doing better than chance
            dom_acc_thresh = 55

            if not train_discrim_only and \
                    np.mean(accuracies_dom) > dom_acc_thresh:
                last_update_g = iteration
                num_update_g += 1
                print('Updating G with adversarial loss ({:d} times)'.format(
                    num_update_g))

                # zero out optimizer gradients
                opt_dis.zero_grad()
                opt_rep.zero_grad()

                # extract features
                if discrim_feat:
                    score_t, feat_t = net(im_t)
                    score_t = Variable(score_t.data, requires_grad=False)
                    f_t = feat_t
                else:
                    score_t = net(im_t)
                    f_t = score_t
                dis_score_t = discriminator(f_t)

                # create fake label: target pixels pretend to be source (1)
                batch, _, h, w = dis_score_t.size()
                target_dom_fake_t = make_variable(
                    torch.ones(batch, h, w).long(), requires_grad=False)

                # compute loss for target net
                loss_gan_t = supervised_loss(dis_score_t, target_dom_fake_t)
                (lambda_g * loss_gan_t).backward()
                losses_rep.append(loss_gan_t.item())
                writer.add_scalar('loss/generator', np.mean(losses_rep),
                                  iteration)

                # optimize target net
                opt_rep.step()

                # log net update info
                info_str += ' G:{:.3f}'.format(np.mean(losses_rep))

            if (not train_discrim_only) and weights_shared and \
                    np.mean(accuracies_dom) > dom_acc_thresh:
                print('Updating G using source supervised loss.')

                # zero out optimizer gradients
                opt_dis.zero_grad()
                opt_rep.zero_grad()

                # extract features
                if discrim_feat:
                    score_s, _ = net(im_s)
                    score_t, _ = net(im_t)
                else:
                    score_s = net(im_s)
                    score_t = net(im_t)

                loss_supervised_s = supervised_loss(score_s, label_s,
                                                    weights=weights)
                loss_supervised_t = supervised_loss(score_t, label_t,
                                                    weights=weights)
                # always backprop the source loss; add the target loss only
                # when target supervision is enabled
                loss_supervised = loss_supervised_s
                if targetSup:
                    loss_supervised += loss_supervised_t
                loss_supervised.backward()

                losses_super_s.append(loss_supervised_s.item())
                info_str += ' clsS:{:.2f}'.format(np.mean(losses_super_s))
                writer.add_scalar('loss/supervised/source',
                                  np.mean(losses_super_s), iteration)

                losses_super_t.append(loss_supervised_t.item())
                info_str += ' clsT:{:.2f}'.format(np.mean(losses_super_t))
                writer.add_scalar('loss/supervised/target',
                                  np.mean(losses_super_t), iteration)

                # optimize target net
                opt_rep.step()

            ###########################
            # Log and compute metrics #
            ###########################
            if iteration % 10 == 0 and iteration > 0:
                # compute metrics over a rolling window of 100 measurements
                intersection, union, acc = seg_accuracy(
                    score_t, label_t.data, num_cls)
                iou_s = IoU(score_s, label_s)
                iou_t = IoU(score_t, label_t)
                rc_s = recall(score_s, label_s)
                rc_t = recall(score_t, label_t)
                IoU_s.append(iou_s.item())
                IoU_t.append(iou_t.item())
                Recall_s.append(rc_s.item())
                Recall_t.append(rc_t.item())
                intersections = np.vstack(
                    [intersections[1:, :], intersection[np.newaxis, :]])
                unions = np.vstack([unions[1:, :], union[np.newaxis, :]])
                accuracy.append(acc.item() * 100)
                acc = np.mean(accuracy)
                mIoU = np.mean(
                    np.maximum(intersections, 1) /
                    np.maximum(unions, 1)) * 100
                info_str += ' IoU:{:0.2f} Recall:{:0.2f}'.format(
                    iou_s.item(), rc_s.item())
                # writer.add_scalar('metrics/acc', np.mean(accuracy), iteration)
                # writer.add_scalar('metrics/mIoU', mIoU, iteration)
                # writer.add_scalar('metrics/RealIoU_Source', np.mean(IoU_s))
                # writer.add_scalar('metrics/RealIoU_Target', np.mean(IoU_t))
                # writer.add_scalar('metrics/RealRecall_Source', np.mean(Recall_s))
                # writer.add_scalar('metrics/RealRecall_Target', np.mean(Recall_t))
                logging.info(info_str)
                print(info_str)

                # dump qualitative results for the first sample in the batch;
                # norm and mxAxis are project-local display helpers
                im_s = Image.fromarray(np.uint8(
                    norm(im_s[0]).permute(1, 2, 0).cpu().data.numpy() * 255))
                im_t = Image.fromarray(np.uint8(
                    norm(im_t[0]).permute(1, 2, 0).cpu().data.numpy() * 255))
                label_s = Image.fromarray(
                    np.uint8(label_s[0].cpu().data.numpy() * 255))
                label_t = Image.fromarray(
                    np.uint8(label_t[0].cpu().data.numpy() * 255))
                score_s = Image.fromarray(
                    np.uint8(mxAxis(score_s[0]).cpu().data.numpy() * 255))
                score_t = Image.fromarray(
                    np.uint8(mxAxis(score_t[0]).cpu().data.numpy() * 255))
                im_s.save(output + "/im_s.png")
                im_t.save(output + "/im_t.png")
                label_s.save(output + "/label_s.png")
                label_t.save(output + "/label_t.png")
                score_s.save(output + "/score_s.png")
                score_t.save(output + "/score_t.png")

            iteration += 1

            ################
            # Save outputs #
            ################

            # every 500 iters save the current model
            if iteration % 500 == 0:
                os.makedirs(output, exist_ok=True)
                if not train_discrim_only:
                    torch.save(net.state_dict(),
                               '{}/net-itercurr.pth'.format(output))
                torch.save(discriminator.state_dict(),
                           '{}/discriminator-itercurr.pth'.format(output))

            # save labeled snapshots
            if iteration % snapshot == 0:
                os.makedirs(output, exist_ok=True)
                if not train_discrim_only:
                    torch.save(net.state_dict(),
                               '{}/net-iter{}.pth'.format(output, iteration))
                torch.save(
                    discriminator.state_dict(),
                    '{}/discriminator-iter{}.pth'.format(output, iteration))

            # warn if the discriminator has not crossed the accuracy
            # threshold for a full epoch (early exit is disabled here)
            if iteration - last_update_g >= len(loader):
                print('No suitable discriminator found -- returning.')
                # import pdb; pdb.set_trace()
                # torch.save(net.state_dict(),
                #            '{}/net-iter{}.pth'.format(output, iteration))
                # iteration = max_iter  # make sure outside loop breaks
                # break

    writer.close()
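

# How a command-line entry point for the trainer above might look. This is a
# hypothetical sketch: the flag names and defaults below are assumptions, not
# the project's actual CLI, and --weights_init must point at a real saved
# state dict for load_state_dict() to succeed.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Adversarial (ADDA-style) FCN domain adaptation')
    parser.add_argument('--output', required=True, help='snapshot directory')
    parser.add_argument('--dataset', nargs=2, required=True,
                        help='source and target dataset names')
    parser.add_argument('--datadir', required=True)
    parser.add_argument('--weights_init', required=True,
                        help='path to a pretrained state dict')
    parser.add_argument('--model', default='fcn8s')
    parser.add_argument('--num_cls', type=int, default=19)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--momentum', type=float, default=0.99)
    parser.add_argument('--batch', type=int, default=1)
    parser.add_argument('--lambda_d', type=float, default=1.0)
    parser.add_argument('--lambda_g', type=float, default=0.1)
    parser.add_argument('--max_iter', type=int, default=100000)
    parser.add_argument('--snapshot', type=int, default=5000)
    args = parser.parse_args()

    # remaining parameters are filled with plausible placeholder values
    main(output=args.output, dataset=args.dataset, datadir=args.datadir,
         lr=args.lr, momentum=args.momentum, snapshot=args.snapshot,
         downscale=None, cls_weights=None, weights_init=args.weights_init,
         num_cls=args.num_cls, lsgan=False, max_iter=args.max_iter,
         lambda_d=args.lambda_d, lambda_g=args.lambda_g,
         train_discrim_only=False, weights_discrim=None, crop_size=None,
         weights_shared=True, discrim_feat=False, half_crop=False,
         batch=args.batch, model=args.model, targetsup=True)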