def train(cfg, local_rank, distributed):
    logger = logging.getLogger("SDCA.trainer")

    # create network
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor = build_feature_extractor(cfg)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)
    if local_rank == 0:
        print(classifier)

    # batch size: half for source and half for target
    batch_size = cfg.SOLVER.BATCH_SIZE // 2
    if distributed:
        pg1 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
        batch_size = int(cfg.SOLVER.BATCH_SIZE / torch.distributed.get_world_size()) // 2
        if not cfg.MODEL.FREEZE_BN:
            feature_extractor = torch.nn.SyncBatchNorm.convert_sync_batchnorm(feature_extractor)
        feature_extractor = torch.nn.parallel.DistributedDataParallel(
            feature_extractor, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True, process_group=pg1
        )
        pg2 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True, process_group=pg2
        )
        torch.autograd.set_detect_anomaly(True)
        torch.distributed.barrier()

    # init optimizer
    optimizer_fea = torch.optim.SGD(feature_extractor.parameters(), lr=cfg.SOLVER.BASE_LR,
                                    momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_fea.zero_grad()

    optimizer_cls = torch.optim.SGD(classifier.parameters(), lr=cfg.SOLVER.BASE_LR * 10,
                                    momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_cls.zero_grad()

    # load checkpoint
    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        feature_weights = checkpoint['feature_extractor'] if distributed else strip_prefix_if_present(
            checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(feature_weights)
        classifier_weights = checkpoint['classifier'] if distributed else strip_prefix_if_present(
            checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)

    # init data loader
    src_train_data = build_dataset(cfg, mode='train', is_source=True)
    tgt_train_data = build_dataset(cfg, mode='train', is_source=False)

    if distributed:
        src_train_sampler = torch.utils.data.distributed.DistributedSampler(src_train_data)
        tgt_train_sampler = torch.utils.data.distributed.DistributedSampler(tgt_train_data)
    else:
        src_train_sampler = None
        tgt_train_sampler = None

    src_train_loader = torch.utils.data.DataLoader(src_train_data, batch_size=batch_size,
                                                   shuffle=(src_train_sampler is None), num_workers=4,
                                                   pin_memory=True, sampler=src_train_sampler, drop_last=True)
    tgt_train_loader = torch.utils.data.DataLoader(tgt_train_data, batch_size=batch_size,
                                                   shuffle=(tgt_train_sampler is None), num_workers=4,
                                                   pin_memory=True, sampler=tgt_train_sampler, drop_last=True)

    # init loss
    ce_criterion = nn.CrossEntropyLoss(ignore_index=255)
    pcl_criterion = PixelContrastiveLoss(cfg)

    # load semantic distributions
    logger.info(">>>>>>>>>>>>>>>> Load semantic distributions >>>>>>>>>>>>>>>>")
    _, backbone_name = cfg.MODEL.NAME.split('_')
    feature_num = 2048 if backbone_name.startswith('resnet') else 1024
    feat_estimator = semantic_dist_estimator(feature_num=feature_num, cfg=cfg)
    if cfg.SOLVER.MULTI_LEVEL:
        out_estimator = semantic_dist_estimator(feature_num=cfg.MODEL.NUM_CLASSES, cfg=cfg)

    iteration = 0
    start_training_time = time.time()
    end = time.time()
    save_to_disk = local_rank == 0
    max_iters = cfg.SOLVER.MAX_ITER
    meters = MetricLogger(delimiter=" ")

    logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>")
    feature_extractor.train()
    classifier.train()

    for i, ((src_input, src_label, src_name), (tgt_input, _, _)) in enumerate(
            zip(src_train_loader, tgt_train_loader)):
        data_time = time.time() - end

        current_lr = adjust_learning_rate(cfg.SOLVER.LR_METHOD, cfg.SOLVER.BASE_LR, iteration, max_iters,
                                          power=cfg.SOLVER.LR_POWER)
        for index in range(len(optimizer_fea.param_groups)):
            optimizer_fea.param_groups[index]['lr'] = current_lr
        for index in range(len(optimizer_cls.param_groups)):
            optimizer_cls.param_groups[index]['lr'] = current_lr * 10

        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()

        src_input = src_input.cuda(non_blocking=True)
        src_label = src_label.cuda(non_blocking=True).long()
        tgt_input = tgt_input.cuda(non_blocking=True)

        src_size = src_input.shape[-2:]
        src_feat = feature_extractor(src_input)
        src_out = classifier(src_feat)
        tgt_feat = feature_extractor(tgt_input)
        tgt_out = classifier(tgt_feat)
        tgt_out_softmax = F.softmax(tgt_out, dim=1)

        # supervision loss
        src_pred = F.interpolate(src_out, size=src_size, mode='bilinear', align_corners=True)
        if cfg.SOLVER.LAMBDA_LOV > 0:
            src_pred_softmax = F.softmax(src_pred, dim=1)
            loss_lov = lovasz_softmax(src_pred_softmax, src_label, ignore=255)
            loss_sup = ce_criterion(src_pred, src_label) + cfg.SOLVER.LAMBDA_LOV * loss_lov
            meters.update(loss_lov=loss_lov.item())
        else:
            loss_sup = ce_criterion(src_pred, src_label)
        meters.update(loss_sup=loss_sup.item())

        # source mask: downsample the ground-truth label
        B, A, Hs, Ws = src_feat.size()
        src_mask = F.interpolate(src_label.unsqueeze(0).float(), size=(Hs, Ws), mode='nearest').squeeze(0).long()
        src_mask = src_mask.contiguous().view(B * Hs * Ws, )
        assert not src_mask.requires_grad

        # target mask: constant threshold
        _, _, Ht, Wt = tgt_feat.size()
        tgt_out_maxvalue, tgt_mask = torch.max(tgt_out_softmax, dim=1)
        # mark low-confidence predictions of every class as ignore (255)
        for cls_idx in range(cfg.MODEL.NUM_CLASSES):
            tgt_mask[(tgt_out_maxvalue < cfg.SOLVER.DELTA) * (tgt_mask == cls_idx)] = 255
        tgt_mask = tgt_mask.contiguous().view(B * Ht * Wt, )
        assert not tgt_mask.requires_grad

        src_feat = src_feat.permute(0, 2, 3, 1).contiguous().view(B * Hs * Ws, A)
        tgt_feat = tgt_feat.permute(0, 2, 3, 1).contiguous().view(B * Ht * Wt, A)

        # update feature-level statistics
        feat_estimator.update(features=src_feat.detach(), labels=src_mask)

        # contrastive loss on both domains
        loss_feat = pcl_criterion(Mean=feat_estimator.Mean.detach(),
                                  CoVariance=feat_estimator.CoVariance.detach(),
                                  feat=src_feat,
                                  labels=src_mask) \
                    + pcl_criterion(Mean=feat_estimator.Mean.detach(),
                                    CoVariance=feat_estimator.CoVariance.detach(),
                                    feat=tgt_feat,
                                    labels=tgt_mask)
        meters.update(loss_feat=loss_feat.item())

        if cfg.SOLVER.MULTI_LEVEL:
            src_out = src_out.permute(0, 2, 3, 1).contiguous().view(B * Hs * Ws, cfg.MODEL.NUM_CLASSES)
            tgt_out = tgt_out.permute(0, 2, 3, 1).contiguous().view(B * Ht * Wt, cfg.MODEL.NUM_CLASSES)

            # update output-level statistics
            out_estimator.update(features=src_out.detach(), labels=src_mask)

            # the proposed contrastive loss on prediction map
            loss_out = pcl_criterion(Mean=out_estimator.Mean.detach(),
                                     CoVariance=out_estimator.CoVariance.detach(),
                                     feat=src_out,
                                     labels=src_mask) \
                       + pcl_criterion(Mean=out_estimator.Mean.detach(),
                                       CoVariance=out_estimator.CoVariance.detach(),
                                       feat=tgt_out,
                                       labels=tgt_mask)
            meters.update(loss_out=loss_out.item())

            loss = loss_sup \
                   + cfg.SOLVER.LAMBDA_FEAT * loss_feat \
                   + cfg.SOLVER.LAMBDA_OUT * loss_out
        else:
            loss = loss_sup + cfg.SOLVER.LAMBDA_FEAT * loss_feat

        loss.backward()
        optimizer_fea.step()
        optimizer_cls.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (cfg.SOLVER.STOP_ITER - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        iteration += 1
        if iteration % 20 == 0 or iteration == max_iters:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.02f} GB"
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer_fea.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0
                )
            )

        if (iteration == cfg.SOLVER.MAX_ITER or iteration % cfg.SOLVER.CHECKPOINT_PERIOD == 0) and save_to_disk:
            filename = os.path.join(cfg.OUTPUT_DIR, "model_iter{:06d}.pth".format(iteration))
            torch.save({'iteration': iteration,
                        'feature_extractor': feature_extractor.state_dict(),
                        'classifier': classifier.state_dict(),
                        'optimizer_fea': optimizer_fea.state_dict(),
                        'optimizer_cls': optimizer_cls.state_dict(),
                        }, filename)

        if iteration == cfg.SOLVER.MAX_ITER:
            break
        if iteration == cfg.SOLVER.STOP_ITER:
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / cfg.SOLVER.STOP_ITER
        )
    )

    return feature_extractor, classifier
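# The trainers in this file rely on a few small helpers defined elsewhere in the
# repository (solver and checkpoint utilities). The sketches below are illustrative
# reconstructions, assuming a standard poly learning-rate schedule and a plain
# state-dict key rewrite; the repository's own implementations may differ.
from collections import OrderedDict


def adjust_learning_rate(method, base_lr, iteration, max_iters, power=0.9):
    # poly schedule: lr = base_lr * (1 - iter / max_iter) ** power
    if method == 'poly':
        return base_lr * ((1.0 - float(iteration) / max_iters) ** power)
    # fall back to a constant learning rate for unrecognized methods
    return base_lr


def strip_prefix_if_present(state_dict, prefix):
    # drop a leading prefix (e.g. 'module.' added by DistributedDataParallel) from
    # every key so the weights load into a non-wrapped model
    if not all(key.startswith(prefix) for key in state_dict.keys()):
        return state_dict
    stripped = OrderedDict()
    for key, value in state_dict.items():
        stripped[key[len(prefix):]] = value
    return stripped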
def train(cfg, local_rank, distributed):
    logger = logging.getLogger("FADA.trainer")
    logger.info("Start training")

    feature_extractor = build_feature_extractor(cfg, adv=True)
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    model_D = build_adversarial_discriminator(cfg)
    model_D.to(device)

    if local_rank == 0:
        print(feature_extractor)
        print(model_D)

    batch_size = cfg.SOLVER.BATCH_SIZE // 2
    if distributed:
        pg1 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
        batch_size = int(cfg.SOLVER.BATCH_SIZE / torch.distributed.get_world_size()) // 2
        if not cfg.MODEL.FREEZE_BN:
            feature_extractor = torch.nn.SyncBatchNorm.convert_sync_batchnorm(feature_extractor)
        feature_extractor = torch.nn.parallel.DistributedDataParallel(
            feature_extractor, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True, process_group=pg1)
        pg2 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True, process_group=pg2)
        pg3 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
        model_D = torch.nn.parallel.DistributedDataParallel(
            model_D, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True, process_group=pg3)
        torch.autograd.set_detect_anomaly(True)
        torch.distributed.barrier()

    optimizer_fea = torch.optim.SGD(feature_extractor.parameters(), lr=cfg.SOLVER.BASE_LR,
                                    momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_fea.zero_grad()

    optimizer_cls = torch.optim.SGD(classifier.parameters(), lr=cfg.SOLVER.BASE_LR * 10,
                                    momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_cls.zero_grad()

    optimizer_D = torch.optim.Adam(model_D.parameters(), lr=cfg.SOLVER.BASE_LR_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = local_rank == 0
    start_epoch = 0
    iteration = 0

    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        model_weights = checkpoint['feature_extractor'] if distributed else strip_prefix_if_present(
            checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(model_weights)
        classifier_weights = checkpoint['classifier'] if distributed else strip_prefix_if_present(
            checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)
        if "model_D" in checkpoint:
            logger.info("Loading model_D from {}".format(cfg.resume))
            model_D_weights = checkpoint['model_D'] if distributed else strip_prefix_if_present(
                checkpoint['model_D'], 'module.')
            model_D.load_state_dict(model_D_weights)
        # if "optimizer_fea" in checkpoint:
        #     logger.info("Loading optimizer_fea from {}".format(cfg.resume))
        #     optimizer_fea.load_state_dict(checkpoint['optimizer_fea'])
        # if "optimizer_cls" in checkpoint:
        #     logger.info("Loading optimizer_cls from {}".format(cfg.resume))
        #     optimizer_cls.load_state_dict(checkpoint['optimizer_cls'])
        # if "optimizer_D" in checkpoint:
        #     logger.info("Loading optimizer_D from {}".format(cfg.resume))
        #     optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        # if "iteration" in checkpoint:
        #     iteration = checkpoint['iteration']

    src_train_data = build_dataset(cfg, mode='train', is_source=True)
    tgt_train_data = build_dataset(cfg, mode='train', is_source=False)

    if distributed:
        src_train_sampler = torch.utils.data.distributed.DistributedSampler(src_train_data)
        tgt_train_sampler = torch.utils.data.distributed.DistributedSampler(tgt_train_data)
    else:
        src_train_sampler = None
        tgt_train_sampler = None

    src_train_loader = torch.utils.data.DataLoader(src_train_data, batch_size=batch_size,
                                                   shuffle=(src_train_sampler is None), num_workers=4,
                                                   pin_memory=True, sampler=src_train_sampler, drop_last=True)
    tgt_train_loader = torch.utils.data.DataLoader(tgt_train_data, batch_size=batch_size,
                                                   shuffle=(tgt_train_sampler is None), num_workers=4,
                                                   pin_memory=True, sampler=tgt_train_sampler, drop_last=True)

    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
    bce_loss = torch.nn.BCELoss(reduction='none')

    max_iters = cfg.SOLVER.MAX_ITER
    source_label = 0
    target_label = 1

    logger.info("Start training")
    meters = MetricLogger(delimiter=" ")
    feature_extractor.train()
    classifier.train()
    model_D.train()
    start_training_time = time.time()
    end = time.time()

    for i, ((src_input, src_label, src_name), (tgt_input, _, _)) in enumerate(
            zip(src_train_loader, tgt_train_loader)):
        data_time = time.time() - end

        current_lr = adjust_learning_rate(cfg.SOLVER.LR_METHOD, cfg.SOLVER.BASE_LR, iteration, max_iters,
                                          power=cfg.SOLVER.LR_POWER)
        current_lr_D = adjust_learning_rate(cfg.SOLVER.LR_METHOD, cfg.SOLVER.BASE_LR_D, iteration, max_iters,
                                            power=cfg.SOLVER.LR_POWER)
        for index in range(len(optimizer_fea.param_groups)):
            optimizer_fea.param_groups[index]['lr'] = current_lr
        for index in range(len(optimizer_cls.param_groups)):
            optimizer_cls.param_groups[index]['lr'] = current_lr * 10
        for index in range(len(optimizer_D.param_groups)):
            optimizer_D.param_groups[index]['lr'] = current_lr_D

        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()
        optimizer_D.zero_grad()

        src_input = src_input.cuda(non_blocking=True)
        src_label = src_label.cuda(non_blocking=True).long()
        tgt_input = tgt_input.cuda(non_blocking=True)

        src_size = src_input.shape[-2:]
        tgt_size = tgt_input.shape[-2:]

        # forward source and (resized) target through the feature extractor together
        inp = torch.cat([src_input, F.interpolate(tgt_input.detach(), src_size)])
        src_fea = feature_extractor(inp)[:batch_size]
        src_pred = classifier(src_fea, src_size)
        temperature = 1.8
        src_pred = src_pred.div(temperature)
        loss_seg = criterion(src_pred, src_label)
        loss_seg.backward()

        # generate soft labels
        src_soft_label = F.softmax(src_pred, dim=1).detach()
        src_soft_label[src_soft_label > 0.9] = 0.9

        tgt_fea = feature_extractor(tgt_input)
        tgt_pred = classifier(tgt_fea, tgt_size)
        tgt_pred = tgt_pred.div(temperature)
        tgt_soft_label = F.softmax(tgt_pred, dim=1)
        tgt_soft_label = tgt_soft_label.detach()
        tgt_soft_label[tgt_soft_label > 0.9] = 0.9

        # adversarial loss: push target features toward the source half of the discriminator
        tgt_D_pred = model_D(tgt_fea, tgt_size)
        loss_adv_tgt = 0.001 * soft_label_cross_entropy(
            tgt_D_pred, torch.cat((tgt_soft_label, torch.zeros_like(tgt_soft_label)), dim=1))
        loss_adv_tgt.backward()

        optimizer_fea.step()
        optimizer_cls.step()

        # discriminator update on detached features
        optimizer_D.zero_grad()

        src_D_pred = model_D(src_fea.detach(), src_size)
        loss_D_src = 0.5 * soft_label_cross_entropy(
            src_D_pred, torch.cat((src_soft_label, torch.zeros_like(src_soft_label)), dim=1))
        loss_D_src.backward()

        tgt_D_pred = model_D(tgt_fea.detach(), tgt_size)
        loss_D_tgt = 0.5 * soft_label_cross_entropy(
            tgt_D_pred, torch.cat((torch.zeros_like(tgt_soft_label), tgt_soft_label), dim=1))
        loss_D_tgt.backward()

        optimizer_D.step()

        meters.update(loss_seg=loss_seg.item())
        meters.update(loss_adv_tgt=loss_adv_tgt.item())
        meters.update(loss_D=(loss_D_src.item() + loss_D_tgt.item()))
        meters.update(loss_D_src=loss_D_src.item())
        meters.update(loss_D_tgt=loss_D_tgt.item())

        iteration = iteration + 1
        n = src_input.size(0)

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (cfg.SOLVER.STOP_ITER - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iters:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer_fea.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))

        if (iteration == cfg.SOLVER.MAX_ITER or iteration % cfg.SOLVER.CHECKPOINT_PERIOD == 0) and save_to_disk:
            filename = os.path.join(output_dir, "model_iter{:06d}.pth".format(iteration))
            torch.save(
                {
                    'iteration': iteration,
                    'feature_extractor': feature_extractor.state_dict(),
                    'classifier': classifier.state_dict(),
                    'model_D': model_D.state_dict(),
                    'optimizer_fea': optimizer_fea.state_dict(),
                    'optimizer_cls': optimizer_cls.state_dict(),
                    'optimizer_D': optimizer_D.state_dict()
                }, filename)

        if iteration == cfg.SOLVER.MAX_ITER:
            break
        if iteration == cfg.SOLVER.STOP_ITER:
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / cfg.SOLVER.MAX_ITER))

    return feature_extractor, classifier
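# `soft_label_cross_entropy` is the loss used for both the adversarial term and the
# discriminator updates above: the discriminator emits 2 * NUM_CLASSES channels (a
# source half and a target half) and is supervised with the classifier's clipped soft
# predictions concatenated with zeros. A minimal sketch, assuming a per-pixel soft
# cross-entropy over the discriminator's log-softmax; the repository's version may
# additionally support per-pixel weighting.
import torch
import torch.nn.functional as F


def soft_label_cross_entropy(pred, soft_label, pixel_weights=None):
    # pred: discriminator logits (N, 2C, H, W); soft_label: soft targets of the same shape
    loss = -soft_label.float() * F.log_softmax(pred, dim=1)
    loss = torch.sum(loss, dim=1)  # collapse the channel dimension -> per-pixel loss
    if pixel_weights is not None:
        loss = pixel_weights * loss
    return torch.mean(loss)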
def train(cfg, local_rank, distributed):
    logger = logging.getLogger("SelfSupervised.trainer")
    logger.info("Start training")

    feature_extractor = build_feature_extractor(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    if local_rank == 0:
        print(feature_extractor)
        print(classifier)

    batch_size = cfg.SOLVER.BATCH_SIZE
    if distributed:
        pg1 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
        batch_size = int(cfg.SOLVER.BATCH_SIZE / torch.distributed.get_world_size())
        if not cfg.MODEL.FREEZE_BN:
            feature_extractor = torch.nn.SyncBatchNorm.convert_sync_batchnorm(feature_extractor)
        feature_extractor = torch.nn.parallel.DistributedDataParallel(
            feature_extractor, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True, process_group=pg1)
        pg2 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True, process_group=pg2)
        torch.autograd.set_detect_anomaly(True)
        torch.distributed.barrier()

    optimizer_fea = torch.optim.SGD(feature_extractor.parameters(), lr=cfg.SOLVER.BASE_LR,
                                    momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_fea.zero_grad()

    optimizer_cls = torch.optim.SGD(classifier.parameters(), lr=cfg.SOLVER.BASE_LR * 10,
                                    momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_cls.zero_grad()

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = local_rank == 0
    iteration = 0

    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        model_weights = checkpoint['feature_extractor'] if distributed else strip_prefix_if_present(
            checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(model_weights)
        classifier_weights = checkpoint['classifier'] if distributed else strip_prefix_if_present(
            checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)

    src_train_data = build_dataset(cfg, mode='train', is_source=True)

    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(src_train_data)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(src_train_data, batch_size=batch_size,
                                               shuffle=(train_sampler is None), num_workers=4,
                                               pin_memory=True, sampler=train_sampler, drop_last=True)

    ce_criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

    max_iters = cfg.SOLVER.MAX_ITER
    logger.info("Start training")
    meters = MetricLogger(delimiter=" ")
    feature_extractor.train()
    classifier.train()
    start_training_time = time.time()
    end = time.time()

    for i, (src_input, src_label, _) in enumerate(train_loader):
        data_time = time.time() - end

        current_lr = adjust_learning_rate(cfg.SOLVER.LR_METHOD, cfg.SOLVER.BASE_LR, iteration, max_iters,
                                          power=cfg.SOLVER.LR_POWER)
        for index in range(len(optimizer_fea.param_groups)):
            optimizer_fea.param_groups[index]['lr'] = current_lr
        for index in range(len(optimizer_cls.param_groups)):
            optimizer_cls.param_groups[index]['lr'] = current_lr * 10

        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()

        src_input = src_input.cuda(non_blocking=True)
        src_label = src_label.cuda(non_blocking=True).long()

        size = src_label.shape[-2:]
        pred = classifier(feature_extractor(src_input), size)

        loss = ce_criterion(pred, src_label)
        loss.backward()

        optimizer_fea.step()
        optimizer_cls.step()

        meters.update(loss_seg=loss.item())

        iteration += 1

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)
        eta_seconds = meters.time.global_avg * (cfg.SOLVER.STOP_ITER - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iters:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.2f} GB",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer_fea.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0,
                ))

        if (iteration % cfg.SOLVER.CHECKPOINT_PERIOD == 0 or iteration == max_iters) and save_to_disk:
            filename = os.path.join(output_dir, "model_iter{:06d}.pth".format(iteration))
            torch.save(
                {
                    'iteration': iteration,
                    'feature_extractor': feature_extractor.state_dict(),
                    'classifier': classifier.state_dict(),
                    'optimizer_fea': optimizer_fea.state_dict(),
                    'optimizer_cls': optimizer_cls.state_dict()
                }, filename)

        if iteration == cfg.SOLVER.MAX_ITER:
            break
        if iteration == cfg.SOLVER.STOP_ITER:
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / cfg.SOLVER.STOP_ITER))

    return feature_extractor, classifier
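# All four trainers share the same (cfg, local_rank, distributed) signature. A minimal
# launch sketch, assuming the usual torch.distributed.launch style of invocation and a
# yacs-style cfg object exposing the fields referenced above (MODEL.*, SOLVER.*,
# OUTPUT_DIR, resume). The module path `core.configs` and the argument names are
# illustrative assumptions, not the repository's actual CLI.
import argparse
import os

import torch

from core.configs import cfg  # hypothetical: the repository's global yacs config


def main():
    parser = argparse.ArgumentParser(description="Domain-adaptive segmentation training")
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()

    # torch.distributed.launch sets WORLD_SIZE; fall back to single-GPU training otherwise
    num_gpus = int(os.environ.get("WORLD_SIZE", 1))
    distributed = num_gpus > 1
    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.freeze()

    train(cfg, args.local_rank, distributed)


if __name__ == "__main__":
    main()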
def train(cfg, local_rank, distributed):
    logger = logging.getLogger("BCDM.trainer")
    logger.info("Start training")

    feature_extractor = build_feature_extractor(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    classifier_2 = build_classifier(cfg)
    classifier_2.to(device)

    if local_rank == 0:
        print(feature_extractor)

    batch_size = cfg.SOLVER.BATCH_SIZE // 2
    if distributed:
        pg1 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
        batch_size = int(cfg.SOLVER.BATCH_SIZE / torch.distributed.get_world_size()) // 2
        if not cfg.MODEL.FREEZE_BN:
            feature_extractor = torch.nn.SyncBatchNorm.convert_sync_batchnorm(feature_extractor)
        feature_extractor = torch.nn.parallel.DistributedDataParallel(
            feature_extractor, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True, process_group=pg1)
        pg2 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True, process_group=pg2)
        pg3 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
        classifier_2 = torch.nn.parallel.DistributedDataParallel(
            classifier_2, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True, process_group=pg3)
        torch.autograd.set_detect_anomaly(True)
        torch.distributed.barrier()

    optimizer_fea = torch.optim.SGD(feature_extractor.parameters(), lr=cfg.SOLVER.BASE_LR,
                                    momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_fea.zero_grad()

    optimizer_cls = torch.optim.SGD(list(classifier.parameters()) + list(classifier_2.parameters()),
                                    lr=cfg.SOLVER.BASE_LR * 10, momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_cls.zero_grad()

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = local_rank == 0
    iteration = 0

    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        model_weights = checkpoint['feature_extractor'] if distributed else strip_prefix_if_present(
            checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(model_weights)
        classifier_weights = checkpoint['classifier'] if distributed else strip_prefix_if_present(
            checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)
        # the second classifier is stored under its own key
        classifier_2_weights = checkpoint['classifier_2'] if distributed else strip_prefix_if_present(
            checkpoint['classifier_2'], 'module.')
        classifier_2.load_state_dict(classifier_2_weights)
        # if "optimizer_fea" in checkpoint:
        #     logger.info("Loading optimizer_fea from {}".format(cfg.resume))
        #     optimizer_fea.load_state_dict(checkpoint['optimizer_fea'])
        # if "optimizer_cls" in checkpoint:
        #     logger.info("Loading optimizer_cls from {}".format(cfg.resume))
        #     optimizer_cls.load_state_dict(checkpoint['optimizer_cls'])
        # if "iteration" in checkpoint:
        #     iteration = checkpoint['iteration']

    src_train_data = build_dataset(cfg, mode='train', is_source=True)
    tgt_train_data = build_dataset(cfg, mode='train', is_source=False)

    if distributed:
        src_train_sampler = torch.utils.data.distributed.DistributedSampler(src_train_data)
        tgt_train_sampler = torch.utils.data.distributed.DistributedSampler(tgt_train_data)
    else:
        src_train_sampler = None
        tgt_train_sampler = None

    src_train_loader = torch.utils.data.DataLoader(src_train_data, batch_size=batch_size,
                                                   shuffle=(src_train_sampler is None), num_workers=4,
                                                   pin_memory=True, sampler=src_train_sampler, drop_last=True)
    tgt_train_loader = torch.utils.data.DataLoader(tgt_train_data, batch_size=batch_size,
                                                   shuffle=(tgt_train_sampler is None), num_workers=4,
                                                   pin_memory=True, sampler=tgt_train_sampler, drop_last=True)

    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

    max_iters = cfg.SOLVER.MAX_ITER
    logger.info("Start training")
    meters = MetricLogger(delimiter=" ")
    feature_extractor.train()
    classifier.train()
    classifier_2.train()
    start_training_time = time.time()
    end = time.time()

    for i, ((src_input, src_label, src_name), (tgt_input, _, _)) in enumerate(
            zip(src_train_loader, tgt_train_loader)):
        data_time = time.time() - end

        current_lr = adjust_learning_rate(cfg.SOLVER.LR_METHOD, cfg.SOLVER.BASE_LR, iteration, max_iters,
                                          power=cfg.SOLVER.LR_POWER)
        for index in range(len(optimizer_fea.param_groups)):
            optimizer_fea.param_groups[index]['lr'] = current_lr
        for index in range(len(optimizer_cls.param_groups)):
            optimizer_cls.param_groups[index]['lr'] = current_lr * 10

        # Step A: train on source (CE loss) and target (entropy loss)
        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()

        src_input = src_input.cuda(non_blocking=True)
        src_label = src_label.cuda(non_blocking=True).long()
        tgt_input = tgt_input.cuda(non_blocking=True)

        src_size = src_input.shape[-2:]
        tgt_size = tgt_input.shape[-2:]

        src_fea = feature_extractor(src_input)
        src_pred = classifier(src_fea, src_size)
        src_pred_2 = classifier_2(src_fea, src_size)
        temperature = 1.8
        src_pred = src_pred.div(temperature)
        src_pred_2 = src_pred_2.div(temperature)

        # source segmentation loss
        loss_seg = criterion(src_pred, src_label) + criterion(src_pred_2, src_label)

        tgt_fea = feature_extractor(tgt_input)
        tgt_pred = classifier(tgt_fea, tgt_size)
        tgt_pred_2 = classifier_2(tgt_fea, tgt_size)
        tgt_pred = F.softmax(tgt_pred, dim=1)
        tgt_pred_2 = F.softmax(tgt_pred_2, dim=1)
        loss_ent = entropy_loss(tgt_pred) + entropy_loss(tgt_pred_2)

        total_loss = loss_seg + cfg.SOLVER.ENT_LOSS * loss_ent
        total_loss.backward()
        optimizer_fea.step()
        optimizer_cls.step()

        # Step B: train the bi-classifier to maximize loss_cdd
        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()

        src_fea = feature_extractor(src_input)
        src_pred = classifier(src_fea, src_size)
        src_pred_2 = classifier_2(src_fea, src_size)
        temperature = 1.8
        src_pred = src_pred.div(temperature)
        src_pred_2 = src_pred_2.div(temperature)
        loss_seg = criterion(src_pred, src_label) + criterion(src_pred_2, src_label)

        tgt_fea = feature_extractor(tgt_input)
        tgt_pred = classifier(tgt_fea, tgt_size)
        tgt_pred_2 = classifier_2(tgt_fea, tgt_size)
        tgt_pred = F.softmax(tgt_pred, dim=1)
        tgt_pred_2 = F.softmax(tgt_pred_2, dim=1)
        loss_ent = entropy_loss(tgt_pred) + entropy_loss(tgt_pred_2)
        loss_cdd = discrepancy_calc(tgt_pred, tgt_pred_2)

        total_loss = loss_seg - cfg.SOLVER.CDD_LOSS * loss_cdd + cfg.SOLVER.ENT_LOSS * loss_ent
        total_loss.backward()
        optimizer_cls.step()

        # Step C: train the feature extractor to minimize loss_cdd
        for k in range(cfg.SOLVER.NUM_K):
            optimizer_fea.zero_grad()
            optimizer_cls.zero_grad()

            tgt_fea = feature_extractor(tgt_input)
            tgt_pred = classifier(tgt_fea, tgt_size)
            tgt_pred_2 = classifier_2(tgt_fea, tgt_size)
            tgt_pred = F.softmax(tgt_pred, dim=1)
            tgt_pred_2 = F.softmax(tgt_pred_2, dim=1)
            loss_ent = entropy_loss(tgt_pred) + entropy_loss(tgt_pred_2)
            loss_cdd = discrepancy_calc(tgt_pred, tgt_pred_2)

            total_loss = cfg.SOLVER.CDD_LOSS * loss_cdd + cfg.SOLVER.ENT_LOSS * loss_ent
            total_loss.backward()
            # update the feature extractor so this step actually minimizes the discrepancy
            optimizer_fea.step()

        meters.update(loss_seg=loss_seg.item())
        meters.update(loss_cdd=loss_cdd.item())
        meters.update(loss_ent=loss_ent.item())

        iteration = iteration + 1
        n = src_input.size(0)

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)
        eta_seconds = meters.time.global_avg * (cfg.SOLVER.STOP_ITER - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iters:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer_fea.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))

        if (iteration == cfg.SOLVER.MAX_ITER or iteration % cfg.SOLVER.CHECKPOINT_PERIOD == 0) and save_to_disk:
            filename = os.path.join(output_dir, "model_iter{:06d}.pth".format(iteration))
            torch.save(
                {
                    'iteration': iteration,
                    'feature_extractor': feature_extractor.state_dict(),
                    'classifier': classifier.state_dict(),
                    'classifier_2': classifier_2.state_dict(),
                    'optimizer_fea': optimizer_fea.state_dict(),
                    'optimizer_cls': optimizer_cls.state_dict()
                }, filename)

        if iteration == cfg.SOLVER.MAX_ITER:
            break
        if iteration == cfg.SOLVER.STOP_ITER:
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / cfg.SOLVER.MAX_ITER))

    return feature_extractor, classifier, classifier_2
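# `entropy_loss` and `discrepancy_calc` are helpers assumed by the bi-classifier trainer
# above. The sketches below are illustrative: entropy_loss as the mean per-pixel entropy
# of a softmax map, and discrepancy_calc as the plain MCD-style L1 discrepancy between
# the two classifiers' probability maps. BCDM's actual classifier determinacy disparity
# is a correlation-based measure, so treat the second function only as a stand-in.
import torch


def entropy_loss(prob):
    # prob: softmax probabilities of shape (N, C, H, W)
    # mean over pixels of the per-pixel entropy -sum_c p * log p
    return -torch.mean(torch.sum(prob * torch.log(prob + 1e-8), dim=1))


def discrepancy_calc(prob_1, prob_2):
    # mean absolute difference between the two classifiers' probability maps
    return torch.mean(torch.abs(prob_1 - prob_2))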