def main():
    """Semi-supervised segmentation training with a class-wise contrastive memory bank.

    Even iterations optimize on labeled samples: supervised cross-entropy plus,
    after ramp-up, a labeled-to-memory contrastive loss. Odd iterations optimize
    on unlabeled samples: pixel-wise weighted pseudo-label cross-entropy, an
    entropy term and, after ramp-up, an unlabeled-to-memory contrastive loss.
    An EMA copy of the model is updated after every step and, when
    ``use_teacher`` is set, produces the pseudo-labels and memory-bank features.

    Every setting (``config``, ``dataset``, ``batch_size``, ``num_iterations``,
    ``checkpoint_dir``, ``start``, ...) is read from names defined outside this
    function, presumably module-level globals set by the script entry point.

    NOTE(review): another ``def main()`` appears later in this source; if both
    live in the same module, the later definition shadows this one - confirm.

    Raises:
        ValueError: if the derived batch sizes are smaller than 2, or if
            ``dataset`` is neither ``'pascal_voc'`` nor ``'cityscapes'``.
    """

    def next_full_batch(loader, loader_iter, expected_size):
        """Return ``(batch, loader_iter)``, restarting ``loader`` when needed.

        The iterator is rebuilt (starting a new epoch) when it is exhausted or
        when it yields a short, incomplete batch.
        """
        batch = next(loader_iter, None)
        if batch is None or batch[0].shape[0] != expected_size:
            loader_iter = iter(loader)
            batch = next(loader_iter)
        return batch, loader_iter

    print(config)
    cudnn.enabled = True
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)
    torch.backends.cudnn.deterministic = True

    # Depending on the pretraining, normalize with BGR or RGB channel order.
    if pretraining == 'COCO':
        from utils.transformsgpu import normalize_bgr as normalize
    else:
        from utils.transformsgpu import normalize_rgb as normalize

    # Augmentation anchoring: every unlabeled image is augmented twice, so only
    # half a batch of unlabeled images is loaded per step.
    batch_size_unlabeled = int(batch_size / 2)
    batch_size_labeled = int(batch_size * 1)
    # Explicit checks instead of ``assert`` so they also run under ``python -O``.
    if batch_size_unlabeled < 2:
        raise ValueError('unlabeled batch size (batch_size / 2) should be at least 2')
    if batch_size_labeled < 2:
        raise ValueError('batch size should be at least 2')

    RAMP_UP_ITERS = 2000  # iterations until contrastive and self-training are taken into account

    # DATASETS / LOADERS
    if dataset == 'pascal_voc':
        data_loader = get_loader(dataset)
        data_path = get_data_path(dataset)
        train_dataset = data_loader(data_path, crop_size=input_size, scale=False,
                                    mirror=False, pretraining=pretraining)
    elif dataset == 'cityscapes':
        data_loader = get_loader('cityscapes')
        data_path = get_data_path('cityscapes')
        if deeplabv2:
            data_aug = Compose([RandomCrop_city(input_size)])
        else:  # for deeplabv3 original resolution
            data_aug = Compose([RandomCrop_city_highres(input_size)])
        train_dataset = data_loader(data_path, is_transform=True, augmentations=data_aug,
                                    img_size=input_size, pretraining=pretraining)
    else:
        # Fail early instead of crashing later with a NameError on ``train_dataset``.
        raise ValueError(f'unsupported dataset: {dataset!r}')

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    partial_size = labeled_samples
    print('Training on number of samples:', partial_size)

    # Class weighting that takes unlabeled data into account incrementally.
    class_weights_curr = ClassBalancing(
        labeled_iters=int(labeled_samples / batch_size_labeled),
        unlabeled_iters=int((train_dataset_size - labeled_samples) / batch_size_unlabeled),
        n_classes=num_classes)

    # Memory bank
    feature_memory = FeatureMemory(num_samples=labeled_samples, dataset=dataset,
                                   memory_per_class=256, feature_size=256,
                                   n_classes=num_classes)

    # Select the labeled / unlabeled partition.
    if split_id is not None:
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted split files.
        with open(split_id, 'rb') as split_file:
            train_ids = pickle.load(split_file)
        print('loading train ids from {}'.format(split_id))
    else:
        train_ids = np.arange(train_dataset_size)
        np.random.shuffle(train_ids)

    # Sampler for labeled data: the first ``partial_size`` ids.
    train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
    trainloader = data.DataLoader(train_dataset, batch_size=batch_size_labeled,
                                  sampler=train_sampler, num_workers=num_workers,
                                  pin_memory=True)
    trainloader_iter = iter(trainloader)

    # Sampler for unlabeled data: the remaining ids.
    train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:])
    trainloader_remain = data.DataLoader(train_dataset, batch_size=batch_size_unlabeled,
                                         sampler=train_remain_sampler,
                                         num_workers=num_workers, pin_memory=True)
    trainloader_remain_iter = iter(trainloader_remain)

    # Supervised loss
    supervised_loss = CrossEntropy2d(ignore_label=ignore_label).cuda()

    # Deeplab model
    if deeplabv2:
        if pretraining == 'COCO':
            # COCO and ImageNet ResNet architectures differ slightly in how the stride is applied.
            from model.deeplabv2 import Res_Deeplab
        else:
            # ImageNet pretrained (more modern modification)
            from model.deeplabv2_imagenet import Res_Deeplab
        # Load pretrained parameters.
        if pretraining == 'COCO':
            saved_state_dict = model_zoo.load_url(
                'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/resnet101COCO-41f33a49.pth'
            )  # COCO pretraining
        else:
            saved_state_dict = model_zoo.load_url(
                'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'
            )  # ImageNet pretraining
    else:
        from model.deeplabv3 import Res_Deeplab50 as Res_Deeplab
        saved_state_dict = model_zoo.load_url(
            'https://download.pytorch.org/models/resnet50-19c8e357.pth'
        )  # ImageNet pretraining

    # Create the network.
    model = Res_Deeplab(num_classes=num_classes)

    # Copy only the pretrained parameters whose name and shape match the model.
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # Optimizer for the segmentation network
    learning_rate_object = Learning_Rate_Object(config['training']['learning_rate'])
    optimizer = torch.optim.SGD(model.optim_parameters(learning_rate_object),
                                lr=learning_rate, momentum=momentum,
                                weight_decay=weight_decay)

    ema_model = create_ema_model(model, Res_Deeplab)
    ema_model.train()
    ema_model = ema_model.cuda()
    model.train()
    model = model.cuda()
    cudnn.benchmark = True

    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(os.path.join(checkpoint_dir, 'config.json'), 'w') as handle:
        json.dump(config, handle, indent=4, sort_keys=False)
    with open(os.path.join(checkpoint_dir, 'train_split.pkl'), 'wb') as split_file:
        pickle.dump(train_ids, split_file)

    interp = nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear',
                         align_corners=True)

    start_iteration = 0
    best_mIoU = 0  # best metric while training
    iters_without_improve = 0
    # Keeps ``i_iter`` defined after the loop even if the loop body never runs.
    i_iter = start_iteration

    # TRAINING
    for i_iter in range(start_iteration, num_iterations):
        model.train()  # set mode to training
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        if i_iter % 2 == 0:
            # ---------------- LABELED SAMPLES ----------------
            batch, trainloader_iter = next_full_batch(
                trainloader, trainloader_iter, batch_size_labeled)
            images, labels, _, _, _ = batch
            images = images.cuda()
            labels = labels.cuda()

            if dataset == 'cityscapes':
                class_weights_curr.add_frequencies_labeled(labels.cpu().numpy())

            images_aug, labels_aug, _, _ = augment_samples(
                images, labels, None, random.random() < 0.2, batch_size_labeled,
                ignore_label, weak=True)

            labeled_pred, labeled_features = model(
                normalize(images_aug, dataset), return_features=True)
            labeled_pred = interp(labeled_pred)

            # Class balancing for Cityscapes, only after ramp-up in this branch.
            class_weights = torch.from_numpy(np.ones((num_classes))).cuda()
            if i_iter > RAMP_UP_ITERS and dataset == 'cityscapes':
                class_weights = torch.from_numpy(
                    class_weights_curr.get_weights(num_iterations, only_labeled=False)).cuda()

            loss = 0

            # SUPERVISED SEGMENTATION
            labeled_loss = supervised_loss(labeled_pred, labels_aug,
                                           weight=class_weights.float())
            loss = loss + labeled_loss

            # CONTRASTIVE LEARNING
            if i_iter > RAMP_UP_ITERS - 1000:
                # Start filling the memory bank 1000 iterations before the contrastive loss.
                with torch.no_grad():
                    # Features of the labeled images from the EMA model (or the student in eval mode).
                    if use_teacher:
                        labeled_pred_ema, labeled_features_ema = ema_model(
                            normalize(images_aug, dataset), return_features=True)
                    else:
                        model.eval()
                        labeled_pred_ema, labeled_features_ema = model(
                            normalize(images_aug, dataset), return_features=True)
                        model.train()
                    labeled_pred_ema = interp(labeled_pred_ema)
                    probability_prediction_ema, label_prediction_ema = torch.max(
                        torch.softmax(labeled_pred_ema, dim=1), dim=1)

                # Resize labels, predictions and probabilities to the feature-map resolution.
                feat_hw = (labeled_features_ema.shape[2], labeled_features_ema.shape[3])
                labels_down = nn.functional.interpolate(
                    labels_aug.float().unsqueeze(1), size=feat_hw, mode='nearest').squeeze(1)
                label_prediction_down = nn.functional.interpolate(
                    label_prediction_ema.float().unsqueeze(1), size=feat_hw,
                    mode='nearest').squeeze(1)
                probability_prediction_down = nn.functional.interpolate(
                    probability_prediction_ema.float().unsqueeze(1), size=feat_hw,
                    mode='nearest').squeeze(1)

                # Keep only positions predicted correctly with confidence > 0.95.
                mask_prediction_correctly = ((label_prediction_down == labels_down)
                                             & (probability_prediction_down > 0.95))

                labeled_features_correct = labeled_features_ema.permute(0, 2, 3, 1)
                labels_down_correct = labels_down[mask_prediction_correctly]
                labeled_features_correct = labeled_features_correct[mask_prediction_correctly, ...]

                # Projected features for the memory bank.
                with torch.no_grad():
                    if use_teacher:
                        proj_labeled_features_correct = ema_model.projection_head(
                            labeled_features_correct)
                    else:
                        model.eval()
                        proj_labeled_features_correct = model.projection_head(
                            labeled_features_correct)
                        model.train()

                # Update the memory bank.
                feature_memory.add_features_from_sample_learned(
                    ema_model, proj_labeled_features_correct, labels_down_correct,
                    batch_size_labeled)

                if i_iter > RAMP_UP_ITERS:
                    # CONTRASTIVE LEARNING ON LABELED DATA: pull labeled features towards
                    # memory-bank features of the same class. Skip positions carrying the
                    # ignore label (e.g. padding from resize/crop augmentation).
                    mask_valid = labels_down != ignore_label
                    labeled_features_all = labeled_features.permute(0, 2, 3, 1)
                    labels_down_all = labels_down[mask_valid]
                    labeled_features_all = labeled_features_all[mask_valid, ...]

                    proj_labeled_features_all = model.projection_head(labeled_features_all)
                    pred_labeled_features_all = model.prediction_head(proj_labeled_features_all)

                    loss_contr_labeled = contrastive_class_to_class_learned_memory(
                        model, pred_labeled_features_all, labels_down_all, num_classes,
                        feature_memory.memory)
                    loss = loss + loss_contr_labeled * 0.1

        else:
            # ---------------- UNLABELED SAMPLES ----------------
            batch_remain, trainloader_remain_iter = next_full_batch(
                trainloader_remain, trainloader_remain_iter, batch_size_unlabeled)
            unlabeled_images, _, _, _, _ = batch_remain
            unlabeled_images = unlabeled_images.cuda()

            # Create pseudo-labels with the EMA model (or the student in eval mode).
            with torch.no_grad():
                if use_teacher:
                    logits_u_w, _ = ema_model(normalize(unlabeled_images, dataset),
                                              return_features=True)
                else:
                    model.eval()
                    logits_u_w, _ = model(normalize(unlabeled_images, dataset),
                                          return_features=True)
                logits_u_w = interp(logits_u_w).detach()  # prediction unlabeled
                softmax_u_w = torch.softmax(logits_u_w, dim=1)
                max_probs, pseudo_label = torch.max(softmax_u_w, dim=1)  # pseudo-labels
            model.train()

            if dataset == 'cityscapes':
                class_weights_curr.add_frequencies_unlabeled(pseudo_label.cpu().numpy())

            # Two independent augmentations of the same unlabeled images; the boolean
            # flag passed to augment_samples is only enabled after ramp-up (p=0.75).
            unlabeled_images_aug1, pseudo_label1, max_probs1, _ = augment_samples(
                unlabeled_images, pseudo_label, max_probs,
                i_iter > RAMP_UP_ITERS and random.random() < 0.75,
                batch_size_unlabeled, ignore_label)
            unlabeled_images_aug2, pseudo_label2, max_probs2, _ = augment_samples(
                unlabeled_images, pseudo_label, max_probs,
                i_iter > RAMP_UP_ITERS and random.random() < 0.75,
                batch_size_unlabeled, ignore_label)

            # Concatenate both augmentations into a single forward pass.
            joined_unlabeled = torch.cat((unlabeled_images_aug1, unlabeled_images_aug2), dim=0)
            joined_pseudolabels = torch.cat((pseudo_label1, pseudo_label2), dim=0)
            joined_maxprobs = torch.cat((max_probs1, max_probs2), dim=0)

            pred_joined_unlabeled, features_joined_unlabeled = model(
                normalize(joined_unlabeled, dataset), return_features=True)
            pred_joined_unlabeled = interp(pred_joined_unlabeled)

            # Class balancing for Cityscapes (no ramp-up condition in this branch).
            if dataset == 'cityscapes':
                class_weights = torch.from_numpy(
                    class_weights_curr.get_weights(num_iterations, only_labeled=False)).cuda()
            else:
                class_weights = torch.from_numpy(np.ones((num_classes))).cuda()

            loss = 0

            # SELF-SUPERVISED SEGMENTATION
            unlabeled_loss = CrossEntropyLoss2dPixelWiseWeighted(
                ignore_index=ignore_label, weight=class_weights.float()).cuda()

            # Pseudo-label weighting: sigmoid ramp-up factor times confidence ** 6.
            pixelWiseWeight = sigmoid_ramp_up(i_iter, RAMP_UP_ITERS) * torch.ones(
                joined_maxprobs.shape).cuda()
            pixelWiseWeight = pixelWiseWeight * torch.pow(joined_maxprobs.detach(), 6)

            # Pseudo-label loss
            loss_ce_unlabeled = unlabeled_loss(pred_joined_unlabeled, joined_pseudolabels,
                                               pixelWiseWeight)
            loss = loss + loss_ce_unlabeled

            # Entropy loss on positions whose pseudo-label is not the ignore label.
            valid_mask = (joined_pseudolabels != ignore_label).unsqueeze(1)
            loss = loss + entropy_loss(
                torch.nn.functional.softmax(pred_joined_unlabeled, dim=1), valid_mask) * 0.01

            if i_iter > RAMP_UP_ITERS:
                # CONTRASTIVE LEARNING ON UNLABELED DATA: align unlabeled features to
                # the labeled features stored in the memory bank.
                joined_pseudolabels_down = nn.functional.interpolate(
                    joined_pseudolabels.float().unsqueeze(1),
                    size=(features_joined_unlabeled.shape[2],
                          features_joined_unlabeled.shape[3]),
                    mode='nearest').squeeze(1)

                # Drop positions carrying the ignore label (e.g. augmentation padding).
                mask = joined_pseudolabels_down != ignore_label
                features_joined_unlabeled = features_joined_unlabeled.permute(0, 2, 3, 1)
                features_joined_unlabeled = features_joined_unlabeled[mask, ...]
                joined_pseudolabels_down = joined_pseudolabels_down[mask]

                proj_feat_unlabeled = model.projection_head(features_joined_unlabeled)
                pred_feat_unlabeled = model.prediction_head(proj_feat_unlabeled)

                loss_contr_unlabeled = contrastive_class_to_class_learned_memory(
                    model, pred_feat_unlabeled, joined_pseudolabels_down, num_classes,
                    feature_memory.memory)
                loss = loss + loss_contr_unlabeled * 0.1

        # Common code: optimize, then update the EMA model with a cosine-scheduled momentum.
        loss.backward()
        optimizer.step()
        m = 1 - (1 - 0.995) * (math.cos(math.pi * i_iter / num_iterations) + 1) / 2
        ema_model = update_ema_variables(ema_model=ema_model, model=model,
                                         alpha_teacher=m, iteration=i_iter)

        if i_iter % save_checkpoint_every == 0 and i_iter != 0:
            _save_checkpoint(i_iter, model, optimizer, config)

        if i_iter % val_per_iter == 0 and i_iter != 0:
            print('iter = {0:6d}/{1:6d}'.format(i_iter, num_iterations))

            model.eval()
            mIoU, eval_loss = evaluate(model, dataset, deeplabv2=deeplabv2,
                                       ignore_label=ignore_label, save_dir=checkpoint_dir,
                                       pretraining=pretraining)
            model.train()

            if mIoU > best_mIoU:
                best_mIoU = mIoU
                if save_teacher:
                    _save_checkpoint(i_iter, ema_model, optimizer, config, save_best=True)
                else:
                    _save_checkpoint(i_iter, model, optimizer, config, save_best=True)
                iters_without_improve = 0
            else:
                iters_without_improve += val_per_iter

            # If there has been no improvement for num_iterations / 5 iterations, reload the
            # best model and keep training from it (the LR schedule continues from i_iter).
            if iters_without_improve > num_iterations / 5.:
                print('Re-loading a previous best model')
                checkpoint = torch.load(os.path.join(checkpoint_dir, f'best_model.pth'))
                model.load_state_dict(checkpoint['model'])
                ema_model = create_ema_model(model, Res_Deeplab)
                ema_model.train()
                ema_model = ema_model.cuda()
                model.train()
                model = model.cuda()
                iters_without_improve = 0  # reset timer

    _save_checkpoint(num_iterations, model, optimizer, config)

    # FINISH TRAINING, evaluate again
    model.eval()
    mIoU, eval_loss = evaluate(model, dataset, deeplabv2=deeplabv2, ignore_label=ignore_label,
                               save_dir=checkpoint_dir, pretraining=pretraining)
    model.train()

    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter, model, optimizer, config, save_best=True)

    # TRY IMPROVING THE BEST MODEL BY UPDATING BN STATS ON WEAK UNLABELED DATA
    checkpoint = torch.load(os.path.join(checkpoint_dir, f'best_model.pth'))
    model.load_state_dict(checkpoint['model'])
    model = model.cuda()

    model = update_BN_weak_unlabeled_data(model, normalize, batch_size_unlabeled,
                                          trainloader_remain)
    model.eval()
    mIoU, eval_loss = evaluate(model, dataset, deeplabv2=deeplabv2, ignore_label=ignore_label,
                               save_dir=checkpoint_dir, pretraining=pretraining)
    model.train()
    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter, model, optimizer, config, save_best=True)

    print('BEST MIOU')
    # ``best_mIoU`` is already the running maximum; the former
    # ``max(best_mIoU_improved, best_mIoU)`` referenced an undefined name.
    print(best_mIoU)

    end = timeit.default_timer()
    # NOTE(review): ``start`` is not defined in this function; presumably set at module level.
    print('Total time: ' + str(end - start) + ' seconds')
def main():
    """Semi-supervised Cityscapes training with additional labeled GTA5 data.

    Each iteration uses one labeled batch and one unlabeled Cityscapes batch.
    The labeled batch alternates between Cityscapes (even iterations) and GTA5
    (odd iterations), and the last 100 iterations use Cityscapes only. The
    loss is:
    - supervised cross-entropy on the labeled batch;
    - pixel-wise weighted pseudo-label cross-entropy on two augmentations of
      the unlabeled batch;
    - an entropy term;
    - on Cityscapes iterations after ramp-up, labeled and unlabeled
      contrastive losses against a class-wise feature memory bank.

    Every setting (``config``, ``dataset``, ``batch_size``, ``num_iterations``,
    ``checkpoint_dir``, ``start``, ...) is read from names defined outside this
    function, presumably module-level globals set by the script entry point.
    """

    def next_full_batch(loader, loader_iter, expected_size):
        """Return ``(batch, loader_iter)``, restarting ``loader`` when needed.

        The iterator is rebuilt (starting a new epoch) when it is exhausted or
        when it yields a short, incomplete batch.
        """
        batch = next(loader_iter, None)
        if batch is None or batch[0].shape[0] != expected_size:
            loader_iter = iter(loader)
            batch = next(loader_iter)
        return batch, loader_iter

    print(config)
    cudnn.enabled = True
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)
    torch.backends.cudnn.deterministic = True

    # Depending on the pretraining, normalize with BGR or RGB channel order.
    if pretraining == 'COCO':
        from utils.transformsgpu import normalize_bgr as normalize
    else:
        from utils.transformsgpu import normalize_rgb as normalize

    batch_size_unlabeled = int(batch_size / 2)
    batch_size_labeled = int(batch_size * 1)

    RAMP_UP_ITERS = 2000  # iterations until contrastive and self-training are taken into account

    # Cityscapes dataset (labeled subset + unlabeled remainder).
    data_loader = get_loader('cityscapes')
    data_path = get_data_path('cityscapes')
    # from 1024x2048 to resize 512x1024 to crop input_size (512x512)
    data_aug = Compose([RandomCrop_city(input_size)])
    train_dataset = data_loader(data_path, is_transform=True, augmentations=data_aug,
                                img_size=input_size, pretraining=pretraining)

    # GTA5 dataset, used as additional labeled data.
    from data.gta5_loader import gtaLoader
    data_path_gta = get_data_path('gta5')
    # from 1024x2048 to resize 512x1024 to crop input_size (512x512)
    data_aug_gta = Compose([RandomCrop_city(input_size)])
    train_dataset_gta = gtaLoader(data_path_gta, is_transform=True,
                                  augmentations=data_aug_gta, img_size=input_size,
                                  pretraining=pretraining)

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    partial_size = labeled_samples
    print('Training on number of samples:', partial_size)

    # Class weighting that takes unlabeled data into account incrementally.
    class_weights_curr = ClassBalancing(
        labeled_iters=int(labeled_samples / batch_size_labeled),
        unlabeled_iters=int((train_dataset_size - labeled_samples) / batch_size_unlabeled),
        n_classes=num_classes)

    # Memory bank
    feature_memory = FeatureMemory(num_samples=labeled_samples, dataset=dataset,
                                   memory_per_class=256, feature_size=256,
                                   n_classes=num_classes)

    # Select the labeled / unlabeled partition.
    if split_id is not None:
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted split files.
        with open(split_id, 'rb') as split_file:
            train_ids = pickle.load(split_file)
        print('loading train ids from {}'.format(split_id))
    else:
        train_ids = np.arange(train_dataset_size)
        np.random.shuffle(train_ids)

    # Labeled Cityscapes: the first ``partial_size`` ids.
    train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
    trainloader = data.DataLoader(train_dataset, batch_size=batch_size_labeled,
                                  sampler=train_sampler, num_workers=num_workers,
                                  pin_memory=True)
    trainloader_iter = iter(trainloader)

    # GTA5: every sample is labeled.
    train_ids_gta = np.arange(len(train_dataset_gta))
    np.random.shuffle(train_ids_gta)
    train_sampler_gta = data.sampler.SubsetRandomSampler(train_ids_gta)
    trainloader_gta = data.DataLoader(train_dataset_gta, batch_size=batch_size_labeled,
                                      sampler=train_sampler_gta, num_workers=num_workers,
                                      pin_memory=True)
    trainloader_iter_gta = iter(trainloader_gta)

    # Unlabeled Cityscapes: the remaining ids.
    train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:])
    trainloader_remain = data.DataLoader(train_dataset, batch_size=batch_size_unlabeled,
                                         sampler=train_remain_sampler,
                                         num_workers=num_workers, pin_memory=True)
    trainloader_remain_iter = iter(trainloader_remain)

    # Supervised loss (the pseudo-label loss is built per iteration with class weights).
    supervised_loss = CrossEntropy2d(ignore_label=ignore_label).cuda()

    # Deeplab model
    if deeplabv2:
        if pretraining == 'COCO':
            # COCO and ImageNet ResNet architectures differ slightly in how the stride is applied.
            from model.deeplabv2 import Res_Deeplab
        else:
            # ImageNet pretrained (more modern modification)
            from model.deeplabv2_imagenet import Res_Deeplab
    else:
        from model.deeplabv3 import Res_Deeplab

    # Create the network.
    model = Res_Deeplab(num_classes=num_classes)

    # Load pretrained parameters.
    if pretraining == 'COCO':
        saved_state_dict = model_zoo.load_url(
            'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/resnet101COCO-41f33a49.pth'
        )  # COCO pretraining
    else:
        saved_state_dict = model_zoo.load_url(
            'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'
        )  # ImageNet pretraining

    # Copy only the pretrained parameters whose name and shape match the model.
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # Optimizer for the segmentation network
    learning_rate_object = Learning_Rate_Object(config['training']['learning_rate'])
    optimizer = torch.optim.SGD(model.optim_parameters(learning_rate_object),
                                lr=learning_rate, momentum=momentum,
                                weight_decay=weight_decay)

    ema_model = create_ema_model(model, Res_Deeplab)
    ema_model.train()
    ema_model = ema_model.cuda()
    model.train()
    model = model.cuda()
    cudnn.benchmark = True

    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(os.path.join(checkpoint_dir, 'config.json'), 'w') as handle:
        json.dump(config, handle, indent=4, sort_keys=False)
    with open(os.path.join(checkpoint_dir, 'train_split.pkl'), 'wb') as split_file:
        pickle.dump(train_ids, split_file)

    interp = nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear',
                         align_corners=True)

    start_iteration = 0
    best_mIoU = 0  # best metric while training
    iters_without_improve = 0
    # Keeps ``i_iter`` defined after the loop even if the loop body never runs.
    i_iter = start_iteration

    # TRAINING
    for i_iter in range(start_iteration, num_iterations):
        model.train()  # set mode to training
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        # LABELED SAMPLES: Cityscapes on even iterations, GTA5 on odd ones, and
        # Cityscapes only during the last 100 iterations. The previous condition
        # (``num_iterations - i_iter > 100``) was inverted and used GTA5 only
        # during those last 100 iterations.
        is_cityscapes = i_iter % 2 == 0 or num_iterations - i_iter <= 100
        if is_cityscapes:
            batch, trainloader_iter = next_full_batch(
                trainloader, trainloader_iter, batch_size_labeled)
        else:
            batch, trainloader_iter_gta = next_full_batch(
                trainloader_gta, trainloader_iter_gta, batch_size_labeled)
        images, labels, _, _, _ = batch
        images = images.cuda()
        labels = labels.cuda()

        # UNLABELED SAMPLES
        batch_remain, trainloader_remain_iter = next_full_batch(
            trainloader_remain, trainloader_remain_iter, batch_size_unlabeled)
        unlabeled_images, _, _, _, _ = batch_remain
        unlabeled_images = unlabeled_images.cuda()

        # Create pseudo-labels with the EMA model (or the student in eval mode).
        with torch.no_grad():
            if use_teacher:
                logits_u_w, _ = ema_model(normalize(unlabeled_images, dataset),
                                          return_features=True)
            else:
                model.eval()
                logits_u_w, _ = model(normalize(unlabeled_images, dataset),
                                      return_features=True)
            logits_u_w = interp(logits_u_w).detach()  # prediction unlabeled
            softmax_u_w = torch.softmax(logits_u_w, dim=1)
            max_probs, pseudo_label = torch.max(softmax_u_w, dim=1)  # pseudo-labels
        model.train()

        if is_cityscapes:
            class_weights_curr.add_frequencies(labels.cpu().numpy(),
                                               pseudo_label.cpu().numpy())

        images2, labels2, _, _ = augment_samples(images, labels, None,
                                                 random.random() < 0.25,
                                                 batch_size_labeled, ignore_label,
                                                 weak=True)

        # UNLABELED DATA: strongly augment each unlabeled image twice and train on the
        # pseudo-labels; ClassMix is only performed after ramp-up.
        do_classmix = i_iter > RAMP_UP_ITERS and random.random() < 0.75
        unlabeled_images_aug1, pseudo_label1, max_probs1, _ = augment_samples(
            unlabeled_images, pseudo_label, max_probs, do_classmix,
            batch_size_unlabeled, ignore_label)

        do_classmix = i_iter > RAMP_UP_ITERS and random.random() < 0.75
        unlabeled_images_aug2, pseudo_label2, max_probs2, _ = augment_samples(
            unlabeled_images, pseudo_label, max_probs, do_classmix,
            batch_size_unlabeled, ignore_label)

        joined_unlabeled = torch.cat((unlabeled_images_aug1, unlabeled_images_aug2), dim=0)
        joined_pseudolabels = torch.cat((pseudo_label1, pseudo_label2), dim=0)
        joined_maxprobs = torch.cat((max_probs1, max_probs2), dim=0)

        pred_joined_unlabeled, features_joined_unlabeled = model(
            normalize(joined_unlabeled, dataset), return_features=True)
        pred_joined_unlabeled = interp(pred_joined_unlabeled)

        joined_labeled = images2
        joined_labels = labels2
        labeled_pred, labeled_features = model(normalize(joined_labeled, dataset),
                                               return_features=True)
        labeled_pred = interp(labeled_pred)

        # Class balancing after ramp-up.
        class_weights = torch.from_numpy(np.ones((num_classes))).cuda()
        if i_iter > RAMP_UP_ITERS:
            class_weights = torch.from_numpy(
                class_weights_curr.get_weights(num_iterations, only_labeled=False)).cuda()

        loss = 0

        # SUPERVISED SEGMENTATION. This term was computed but its addition was
        # commented out, so the labels never supervised the classifier. It is
        # restored to match the supervised step of the sibling training loop.
        labeled_loss = supervised_loss(labeled_pred, joined_labels,
                                       weight=class_weights.float())
        loss = loss + labeled_loss

        # SELF-SUPERVISED SEGMENTATION: cross-entropy against the pseudo-labels.
        unlabeled_loss = CrossEntropyLoss2dPixelWiseWeighted(
            ignore_index=ignore_label, weight=class_weights.float()).cuda()

        # Pseudo-label weighting: sigmoid ramp-up factor times confidence ** 6.
        pixelWiseWeight = sigmoid_ramp_up(i_iter, RAMP_UP_ITERS) * torch.ones(
            joined_maxprobs.shape).cuda()
        pixelWiseWeight = pixelWiseWeight * torch.pow(joined_maxprobs.detach(), 6)

        # Pseudo-label loss
        loss_ce_unlabeled = unlabeled_loss(pred_joined_unlabeled, joined_pseudolabels,
                                           pixelWiseWeight)
        loss = loss + loss_ce_unlabeled

        # Entropy loss on positions whose pseudo-label is not the ignore label.
        valid_mask = (joined_pseudolabels != ignore_label).unsqueeze(1)
        loss = loss + entropy_loss(
            torch.nn.functional.softmax(pred_joined_unlabeled, dim=1), valid_mask) * 0.01

        # CONTRASTIVE LEARNING (Cityscapes iterations only)
        if is_cityscapes and i_iter > RAMP_UP_ITERS - 1000:
            # Start filling the memory bank 1000 iterations before the contrastive loss.
            with torch.no_grad():
                if use_teacher:
                    labeled_pred_ema, labeled_features_ema = ema_model(
                        normalize(joined_labeled, dataset), return_features=True)
                else:
                    model.eval()
                    labeled_pred_ema, labeled_features_ema = model(
                        normalize(joined_labeled, dataset), return_features=True)
                    model.train()
                labeled_pred_ema = interp(labeled_pred_ema)
                probability_prediction_ema, label_prediction_ema = torch.max(
                    torch.softmax(labeled_pred_ema, dim=1), dim=1)

            # Resize labels, predictions and probabilities to the feature-map resolution.
            feat_hw = (labeled_features_ema.shape[2], labeled_features_ema.shape[3])
            labels_down = nn.functional.interpolate(
                joined_labels.float().unsqueeze(1), size=feat_hw, mode='nearest').squeeze(1)
            label_prediction_down = nn.functional.interpolate(
                label_prediction_ema.float().unsqueeze(1), size=feat_hw,
                mode='nearest').squeeze(1)
            probability_prediction_down = nn.functional.interpolate(
                probability_prediction_ema.float().unsqueeze(1), size=feat_hw,
                mode='nearest').squeeze(1)

            # Keep only positions predicted correctly with confidence > 0.95.
            mask_prediction_correctly = ((label_prediction_down == labels_down)
                                         & (probability_prediction_down > 0.95))

            labeled_features_correct = labeled_features_ema.permute(0, 2, 3, 1)
            labels_down_correct = labels_down[mask_prediction_correctly]
            labeled_features_correct = labeled_features_correct[mask_prediction_correctly, ...]

            # Projected features for the memory bank.
            with torch.no_grad():
                if use_teacher:
                    proj_labeled_features_correct = ema_model.projection_head(
                        labeled_features_correct)
                else:
                    model.eval()
                    proj_labeled_features_correct = model.projection_head(
                        labeled_features_correct)
                    model.train()

            # Update the memory bank.
            feature_memory.add_features_from_sample_learned(
                ema_model, proj_labeled_features_correct, labels_down_correct,
                batch_size_labeled)

            if i_iter > RAMP_UP_ITERS:
                # LABELED TO LABELED: pull every non-ignored labeled feature towards the
                # memory-bank features of its class.
                mask_valid = labels_down != ignore_label
                labeled_features_all = labeled_features.permute(0, 2, 3, 1)
                labels_down_all = labels_down[mask_valid]
                labeled_features_all = labeled_features_all[mask_valid, ...]

                proj_labeled_features_all = model.projection_head(labeled_features_all)
                pred_labeled_features_all = model.prediction_head(proj_labeled_features_all)

                loss_contr_labeled = contrastive_class_to_class_learned_memory(
                    model, pred_labeled_features_all, labels_down_all, num_classes,
                    feature_memory.memory)
                loss = loss + loss_contr_labeled * 0.2

                # CONTRASTIVE LEARNING ON UNLABELED DATA: align unlabeled features
                # to the labeled features stored in the memory bank.
                joined_pseudolabels_down = nn.functional.interpolate(
                    joined_pseudolabels.float().unsqueeze(1),
                    size=(features_joined_unlabeled.shape[2],
                          features_joined_unlabeled.shape[3]),
                    mode='nearest').squeeze(1)

                # Drop positions whose pseudo-label is the ignore label (e.g. padding
                # introduced by zoom-out augmentations).
                mask = joined_pseudolabels_down != ignore_label
                features_joined_unlabeled = features_joined_unlabeled.permute(0, 2, 3, 1)
                features_joined_unlabeled = features_joined_unlabeled[mask, ...]
                joined_pseudolabels_down = joined_pseudolabels_down[mask]

                proj_feat_unlabeled = model.projection_head(features_joined_unlabeled)
                pred_feat_unlabeled = model.prediction_head(proj_feat_unlabeled)

                loss_contr_unlabeled = contrastive_class_to_class_learned_memory(
                    model, pred_feat_unlabeled, joined_pseudolabels_down, num_classes,
                    feature_memory.memory)
                loss = loss + loss_contr_unlabeled * 0.2

        # Optimize, then update the EMA model with a cosine-scheduled momentum.
        loss.backward()
        optimizer.step()
        m = 1 - (1 - 0.995) * (math.cos(math.pi * i_iter / num_iterations) + 1) / 2
        ema_model = update_ema_variables(ema_model=ema_model, model=model,
                                         alpha_teacher=m, iteration=i_iter)

        if i_iter % save_checkpoint_every == 0 and i_iter != 0:
            _save_checkpoint(i_iter, model, optimizer, config)

        if i_iter % val_per_iter == 0 and i_iter != 0:
            print('iter = {0:6d}/{1:6d}'.format(i_iter, num_iterations))

            model.eval()
            # ``deeplabv2`` is now passed here too, matching the final evaluations.
            mIoU, eval_loss = evaluate(model, dataset, deeplabv2=deeplabv2,
                                       ignore_label=ignore_label, save_dir=checkpoint_dir,
                                       pretraining=pretraining)
            model.train()

            if mIoU > best_mIoU:
                best_mIoU = mIoU
                if save_teacher:
                    _save_checkpoint(i_iter, ema_model, optimizer, config, save_best=True)
                else:
                    _save_checkpoint(i_iter, model, optimizer, config, save_best=True)
                iters_without_improve = 0
            else:
                iters_without_improve += val_per_iter

            # If there has been no improvement for num_iterations / 5 iterations, reload the
            # best model and keep training from it (the LR schedule continues from i_iter).
            if iters_without_improve > num_iterations / 5.:
                print('Re-loading a previous best model')
                checkpoint = torch.load(os.path.join(checkpoint_dir, f'best_model.pth'))
                model.load_state_dict(checkpoint['model'])
                ema_model = create_ema_model(model, Res_Deeplab)
                ema_model.train()
                ema_model = ema_model.cuda()
                model.train()
                model = model.cuda()
                iters_without_improve = 0  # reset timer

    _save_checkpoint(num_iterations, model, optimizer, config)

    # FINISH TRAINING, evaluate again
    model.eval()
    mIoU, eval_loss = evaluate(model, dataset, deeplabv2=deeplabv2, ignore_label=ignore_label,
                               save_dir=checkpoint_dir, pretraining=pretraining)
    model.train()

    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter, model, optimizer, config, save_best=True)

    # TRY IMPROVING THE BEST MODEL BY UPDATING BN STATS ON WEAK UNLABELED DATA
    checkpoint = torch.load(os.path.join(checkpoint_dir, f'best_model.pth'))
    model.load_state_dict(checkpoint['model'])
    model = model.cuda()

    model = update_BN_weak_unlabeled_data(model, normalize, batch_size_unlabeled,
                                          trainloader_remain)
    model.eval()
    mIoU, eval_loss = evaluate(model, dataset, deeplabv2=deeplabv2, ignore_label=ignore_label,
                               save_dir=checkpoint_dir, pretraining=pretraining)
    model.train()
    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter, model, optimizer, config, save_best=True)

    print('BEST MIOU')
    # ``best_mIoU`` is already the running maximum; the former
    # ``max(best_mIoU_improved, best_mIoU)`` referenced an undefined name.
    print(best_mIoU)

    end = timeit.default_timer()
    # NOTE(review): ``start`` is not defined in this function; presumably set at module level.
    print('Total time: ' + str(end - start) + ' seconds')