コード例 #1
0
def main():
    """Run semi-supervised segmentation training end to end.

    All settings come from module-level globals defined elsewhere in this
    script (e.g. ``config``, ``dataset``, ``batch_size``, ``num_iterations``,
    ``labeled_samples``, ``use_teacher``, ``checkpoint_dir``, ``pretraining``).

    Steps:
      1. Seed ``torch`` / ``numpy`` / ``random`` and build the data loaders.
         ``train_ids[:labeled_samples]`` is the labeled subset and the
         remaining ids are treated as unlabeled.
      2. Build a DeepLab model, copy matching pretrained weights into it, and
         create an EMA copy (``create_ema_model``).
      3. Alternate iterations:
         - even ``i_iter``: labeled step (supervised cross-entropy,
           memory-bank update, labeled contrastive loss);
         - odd ``i_iter``: unlabeled step (pseudo-label self-training,
           entropy loss, unlabeled contrastive loss).
      4. Every ``val_per_iter`` iterations evaluate, save the best checkpoint,
         and reload the best one if mIoU has not improved for more than
         ``num_iterations / 5`` iterations.
      5. After training, evaluate the final model. Then reload the best
         checkpoint, pass it through ``update_BN_weak_unlabeled_data`` and
         evaluate once more.

    Raises:
        ValueError: if ``dataset`` is neither ``'pascal_voc'`` nor
            ``'cityscapes'``, or if the derived labeled/unlabeled batch sizes
            are smaller than 2.
    """
    print(config)
    cudnn.enabled = True
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)
    torch.backends.cudnn.deterministic = True

    # Input normalization depends on the pretraining: BGR for COCO, RGB otherwise.
    if pretraining == 'COCO':
        from utils.transformsgpu import normalize_bgr as normalize
    else:
        from utils.transformsgpu import normalize_rgb as normalize

    # Halved because of augmentation anchoring: every unlabeled sample is
    # augmented twice and both views are concatenated into one forward pass.
    batch_size_unlabeled = int(batch_size / 2)
    batch_size_labeled = int(batch_size * 1)
    # Explicit checks instead of `assert` so validation survives `python -O`.
    if batch_size_unlabeled < 2:
        raise ValueError(
            'unlabeled batch size (batch_size / 2) must be at least 2, '
            'got {}'.format(batch_size_unlabeled))
    if batch_size_labeled < 2:
        raise ValueError('labeled batch size must be at least 2, '
                         'got {}'.format(batch_size_labeled))
    RAMP_UP_ITERS = 2000  # iterations until contrastive and self-training are taken into account

    # DATASETS / LOADERS
    if dataset == 'pascal_voc':
        data_loader = get_loader(dataset)
        data_path = get_data_path(dataset)
        train_dataset = data_loader(data_path,
                                    crop_size=input_size,
                                    scale=False,
                                    mirror=False,
                                    pretraining=pretraining)

    elif dataset == 'cityscapes':
        data_loader = get_loader('cityscapes')
        data_path = get_data_path('cityscapes')
        if deeplabv2:
            data_aug = Compose([RandomCrop_city(input_size)])
        else:  # for deeplabv3 original resolution
            data_aug = Compose([RandomCrop_city_highres(input_size)])
        train_dataset = data_loader(data_path,
                                    is_transform=True,
                                    augmentations=data_aug,
                                    img_size=input_size,
                                    pretraining=pretraining)

    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError('unsupported dataset: {!r}'.format(dataset))

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    partial_size = labeled_samples
    print('Training on number of samples:', partial_size)

    # Class weighting that takes unlabeled data into account incrementally.
    class_weights_curr = ClassBalancing(
        labeled_iters=int(labeled_samples / batch_size_labeled),
        unlabeled_iters=int(
            (train_dataset_size - labeled_samples) / batch_size_unlabeled),
        n_classes=num_classes)
    # Memory Bank
    feature_memory = FeatureMemory(num_samples=labeled_samples,
                                   dataset=dataset,
                                   memory_per_class=256,
                                   feature_size=256,
                                   n_classes=num_classes)

    # Select the partition. The first `partial_size` ids are labeled.
    if split_id is not None:
        # NOTE(security): pickle.load executes arbitrary code; only point
        # `split_id` at trusted files.
        with open(split_id, 'rb') as split_file:
            train_ids = pickle.load(split_file)
        print('loading train ids from {}'.format(split_id))
    else:
        train_ids = np.arange(train_dataset_size)
        np.random.shuffle(train_ids)

    # Samplers for labeled data
    train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=batch_size_labeled,
                                  sampler=train_sampler,
                                  num_workers=num_workers,
                                  pin_memory=True)
    trainloader_iter = iter(trainloader)

    # Samplers for unlabeled data
    train_remain_sampler = data.sampler.SubsetRandomSampler(
        train_ids[partial_size:])
    trainloader_remain = data.DataLoader(train_dataset,
                                         batch_size=batch_size_unlabeled,
                                         sampler=train_remain_sampler,
                                         num_workers=num_workers,
                                         pin_memory=True)
    trainloader_remain_iter = iter(trainloader_remain)

    # supervised loss
    supervised_loss = CrossEntropy2d(ignore_label=ignore_label).cuda()
    ''' Deeplab model '''
    # Define network
    if deeplabv2:
        if pretraining == 'COCO':  # coco and imagenet resnet architectures differ a little, just on how to do the stride
            from model.deeplabv2 import Res_Deeplab
        else:  # imagenet pretrained (more modern modification)
            from model.deeplabv2_imagenet import Res_Deeplab

        # load pretrained parameters
        if pretraining == 'COCO':
            saved_state_dict = model_zoo.load_url(
                'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/resnet101COCO-41f33a49.pth'
            )  # COCO pretraining
        else:
            saved_state_dict = model_zoo.load_url(
                'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'
            )  # ImageNet pretraining

    else:
        from model.deeplabv3 import Res_Deeplab50 as Res_Deeplab
        saved_state_dict = model_zoo.load_url(
            'https://download.pytorch.org/models/resnet50-19c8e357.pth'
        )  # ImageNet pretraining

    # create network
    model = Res_Deeplab(num_classes=num_classes)

    # Copy only the pretrained tensors whose name and shape match the model.
    # This lets a different classifier head keep its own initialization.
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # Optimizer for segmentation network
    learning_rate_object = Learning_Rate_Object(
        config['training']['learning_rate'])

    optimizer = torch.optim.SGD(model.optim_parameters(learning_rate_object),
                                lr=learning_rate,
                                momentum=momentum,
                                weight_decay=weight_decay)

    ema_model = create_ema_model(model, Res_Deeplab)
    ema_model.train()
    ema_model = ema_model.cuda()
    model.train()
    model = model.cuda()
    cudnn.benchmark = True

    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(os.path.join(checkpoint_dir, 'config.json'), 'w') as handle:
        json.dump(config, handle, indent=4, sort_keys=False)
    with open(os.path.join(checkpoint_dir, 'train_split.pkl'), 'wb') as handle:
        pickle.dump(train_ids, handle)

    interp = nn.Upsample(size=(input_size[0], input_size[1]),
                         mode='bilinear',
                         align_corners=True)

    epochs_since_start = 0
    start_iteration = 0
    best_mIoU = 0  # best metric while training
    iters_without_improve = 0

    # TRAINING
    for i_iter in range(start_iteration, num_iterations):
        model.train()  # set mode to training
        optimizer.zero_grad()

        loss_l_value = 0.
        adjust_learning_rate(optimizer, i_iter)

        # Even iterations optimize on labeled data, odd ones on unlabeled data.
        labeled_turn = i_iter % 2 == 0

        if labeled_turn:  # labeled data optimization
            ''' LABELED SAMPLES '''
            # Get batch. If it is smaller than expected (e.g. the final partial
            # batch of an epoch), fetch the next one instead. Only
            # StopIteration means "epoch finished"; any other error propagates.
            try:
                batch = next(trainloader_iter)
                if batch[0].shape[0] != batch_size_labeled:
                    batch = next(trainloader_iter)
            except StopIteration:  # finished epoch, rebuild the iterator
                epochs_since_start = epochs_since_start + 1
                trainloader_iter = iter(trainloader)
                batch = next(trainloader_iter)

            images, labels, _, _, _ = batch
            images = images.cuda()
            labels = labels.cuda()

            model.train()

            if dataset == 'cityscapes':
                class_weights_curr.add_frequencies_labeled(
                    labels.cpu().numpy())

            images_aug, labels_aug, _, _ = augment_samples(
                images,
                labels,
                None,
                random.random() < 0.2,
                batch_size_labeled,
                ignore_label,
                weak=True)

            # labeled data
            labeled_pred, labeled_features = model(normalize(
                images_aug, dataset),
                                                   return_features=True)
            labeled_pred = interp(labeled_pred)

            # apply class balance for cityscapes dataset
            class_weights = torch.from_numpy(np.ones((num_classes))).cuda()
            if i_iter > RAMP_UP_ITERS and dataset == 'cityscapes':
                class_weights = torch.from_numpy(
                    class_weights_curr.get_weights(num_iterations,
                                                   only_labeled=False)).cuda()

            loss = 0

            # SUPERVISED SEGMENTATION
            labeled_loss = supervised_loss(labeled_pred,
                                           labels_aug,
                                           weight=class_weights.float())
            loss = loss + labeled_loss

            # CONTRASTIVE LEARNING
            if i_iter > RAMP_UP_ITERS - 1000:
                # Build the memory bank 1000 iterations before the contrastive
                # loss is switched on.

                with torch.no_grad():
                    # Get feature vectors from labeled images with EMA model
                    if use_teacher:
                        labeled_pred_ema, labeled_features_ema = ema_model(
                            normalize(images_aug, dataset),
                            return_features=True)
                    else:
                        model.eval()
                        labeled_pred_ema, labeled_features_ema = model(
                            normalize(images_aug, dataset),
                            return_features=True)
                        model.train()
                    labeled_pred_ema = interp(labeled_pred_ema)
                    probability_prediction_ema, label_prediction_ema = torch.max(
                        torch.softmax(labeled_pred_ema,
                                      dim=1), dim=1)  # Get pseudolabels

                # Resize labels, predictions and probabilities to feature map resolution
                labels_down = nn.functional.interpolate(
                    labels_aug.float().unsqueeze(1),
                    size=(labeled_features_ema.shape[2],
                          labeled_features_ema.shape[3]),
                    mode='nearest').squeeze(1)
                label_prediction_down = nn.functional.interpolate(
                    label_prediction_ema.float().unsqueeze(1),
                    size=(labeled_features_ema.shape[2],
                          labeled_features_ema.shape[3]),
                    mode='nearest').squeeze(1)
                probability_prediction_down = nn.functional.interpolate(
                    probability_prediction_ema.float().unsqueeze(1),
                    size=(labeled_features_ema.shape[2],
                          labeled_features_ema.shape[3]),
                    mode='nearest').squeeze(1)

                # Mask of pixels where the labeled prediction is correct and
                # has confidence above 0.95.
                mask_prediction_correctly = (
                    (label_prediction_down == labels_down).float() *
                    (probability_prediction_down > 0.95).float()).bool()

                # Apply the filter mask to the features and their labels
                labeled_features_correct = labeled_features_ema.permute(
                    0, 2, 3, 1)
                labels_down_correct = labels_down[mask_prediction_correctly]
                labeled_features_correct = labeled_features_correct[
                    mask_prediction_correctly, ...]

                # get projected features
                with torch.no_grad():
                    if use_teacher:
                        proj_labeled_features_correct = ema_model.projection_head(
                            labeled_features_correct)
                    else:
                        model.eval()
                        proj_labeled_features_correct = model.projection_head(
                            labeled_features_correct)
                        model.train()

                # update memory bank
                feature_memory.add_features_from_sample_learned(
                    ema_model, proj_labeled_features_correct,
                    labels_down_correct, batch_size_labeled)

            if i_iter > RAMP_UP_ITERS:
                '''
                CONTRASTIVE LEARNING ON LABELED DATA. Force features from labeled samples, to be similar to other features from the same class (which also leads to good predictions
                '''
                # Mask out features whose label is ignore_label (zero-padding
                # from data augmentation such as resize/crop).
                mask_prediction_correctly = (labels_down != ignore_label)

                labeled_features_all = labeled_features.permute(0, 2, 3, 1)
                labels_down_all = labels_down[mask_prediction_correctly]
                labeled_features_all = labeled_features_all[
                    mask_prediction_correctly, ...]

                # get predicted features
                proj_labeled_features_all = model.projection_head(
                    labeled_features_all)
                pred_labeled_features_all = model.prediction_head(
                    proj_labeled_features_all)

                # Apply contrastive learning loss
                loss_contr_labeled = contrastive_class_to_class_learned_memory(
                    model, pred_labeled_features_all, labels_down_all,
                    num_classes, feature_memory.memory)

                loss = loss + loss_contr_labeled * 0.1

        else:  # unlabeled data optimization
            ''' UNLABELED SAMPLES '''
            # Same short-batch / end-of-epoch handling as the labeled loader.
            try:
                batch_remain = next(trainloader_remain_iter)
                if batch_remain[0].shape[0] != batch_size_unlabeled:
                    batch_remain = next(trainloader_remain_iter)
            except StopIteration:
                trainloader_remain_iter = iter(trainloader_remain)
                batch_remain = next(trainloader_remain_iter)

            # Unlabeled
            unlabeled_images, _, _, _, _ = batch_remain
            unlabeled_images = unlabeled_images.cuda()

            # Create pseudolabels
            with torch.no_grad():
                if use_teacher:
                    logits_u_w, features_weak_unlabeled = ema_model(
                        normalize(unlabeled_images, dataset),
                        return_features=True)
                else:
                    model.eval()
                    logits_u_w, features_weak_unlabeled = model(
                        normalize(unlabeled_images, dataset),
                        return_features=True)
                logits_u_w = interp(
                    logits_u_w).detach()  # prediction unlabeled
                softmax_u_w = torch.softmax(logits_u_w, dim=1)
                max_probs, pseudo_label = torch.max(softmax_u_w,
                                                    dim=1)  # Get pseudolabels

            model.train()

            if dataset == 'cityscapes':
                class_weights_curr.add_frequencies_unlabeled(
                    pseudo_label.cpu().numpy())
            '''
            UNLABELED DATA
            '''
            unlabeled_images_aug1, pseudo_label1, max_probs1, unlabeled_aug1_params = augment_samples(
                unlabeled_images, pseudo_label, max_probs,
                i_iter > RAMP_UP_ITERS and random.random() < 0.75,
                batch_size_unlabeled, ignore_label)

            unlabeled_images_aug2, pseudo_label2, max_probs2, unlabeled_aug2_params = augment_samples(
                unlabeled_images, pseudo_label, max_probs,
                i_iter > RAMP_UP_ITERS and random.random() < 0.75,
                batch_size_unlabeled, ignore_label)
            # concatenate two augmentations of unlabeled data
            joined_unlabeled = torch.cat(
                (unlabeled_images_aug1, unlabeled_images_aug2), dim=0)
            joined_pseudolabels = torch.cat((pseudo_label1, pseudo_label2),
                                            dim=0)
            joined_maxprobs = torch.cat((max_probs1, max_probs2), dim=0)

            pred_joined_unlabeled, features_joined_unlabeled = model(
                normalize(joined_unlabeled, dataset), return_features=True)
            pred_joined_unlabeled = interp(pred_joined_unlabeled)

            # apply class balance for cityscapes dataset
            if dataset == 'cityscapes':
                class_weights = torch.from_numpy(
                    class_weights_curr.get_weights(num_iterations,
                                                   only_labeled=False)).cuda()
            else:
                class_weights = torch.from_numpy(np.ones((num_classes))).cuda()

            loss = 0

            # SELF-SUPERVISED SEGMENTATION
            unlabeled_loss = CrossEntropyLoss2dPixelWiseWeighted(
                ignore_index=ignore_label,
                weight=class_weights.float()).cuda()

            # Pseudo-label weighting: a sigmoid ramp-up times the
            # pseudo-label confidence raised to the 6th power.
            pixelWiseWeight = sigmoid_ramp_up(
                i_iter, RAMP_UP_ITERS) * torch.ones(
                    joined_maxprobs.shape).cuda()
            pixelWiseWeight = pixelWiseWeight * torch.pow(
                joined_maxprobs.detach(), 6)

            # Pseudo-label loss
            loss_ce_unlabeled = unlabeled_loss(pred_joined_unlabeled,
                                               joined_pseudolabels,
                                               pixelWiseWeight)

            loss = loss + loss_ce_unlabeled

            # entropy loss
            valid_mask = (joined_pseudolabels != ignore_label).unsqueeze(1)
            loss = loss + entropy_loss(
                torch.nn.functional.softmax(pred_joined_unlabeled, dim=1),
                valid_mask) * 0.01

            if i_iter > RAMP_UP_ITERS:
                '''
                CONTRASTIVE LEARNING ON UNLABELED DATA. align unlabeled features to labeled features
                '''
                joined_pseudolabels_down = nn.functional.interpolate(
                    joined_pseudolabels.float().unsqueeze(1),
                    size=(features_joined_unlabeled.shape[2],
                          features_joined_unlabeled.shape[3]),
                    mode='nearest').squeeze(1)

                # Mask out features whose pseudo-label is ignore_label
                # (zero-padding from data augmentation such as resize/crop).
                mask = (joined_pseudolabels_down != ignore_label)

                features_joined_unlabeled = features_joined_unlabeled.permute(
                    0, 2, 3, 1)
                features_joined_unlabeled = features_joined_unlabeled[mask,
                                                                      ...]
                joined_pseudolabels_down = joined_pseudolabels_down[mask]

                # get predicted features
                proj_feat_unlabeled = model.projection_head(
                    features_joined_unlabeled)
                pred_feat_unlabeled = model.prediction_head(
                    proj_feat_unlabeled)

                # Apply contrastive learning loss
                loss_contr_unlabeled = contrastive_class_to_class_learned_memory(
                    model, pred_feat_unlabeled, joined_pseudolabels_down,
                    num_classes, feature_memory.memory)

                loss = loss + loss_contr_unlabeled * 0.1

        # common code

        loss_l_value += loss.item()

        # optimize
        loss.backward()
        optimizer.step()

        # EMA momentum follows a cosine schedule from 0.995 (i_iter=0)
        # towards 1.0 (i_iter=num_iterations).
        m = 1 - (1 - 0.995) * (math.cos(math.pi * i_iter / num_iterations) +
                               1) / 2
        ema_model = update_ema_variables(ema_model=ema_model,
                                         model=model,
                                         alpha_teacher=m,
                                         iteration=i_iter)

        if i_iter % save_checkpoint_every == 0 and i_iter != 0:
            _save_checkpoint(i_iter, model, optimizer, config)

        if i_iter % val_per_iter == 0 and i_iter != 0:
            print('iter = {0:6d}/{1:6d}'.format(i_iter, num_iterations))

            model.eval()
            mIoU, eval_loss = evaluate(model,
                                       dataset,
                                       deeplabv2=deeplabv2,
                                       ignore_label=ignore_label,
                                       save_dir=checkpoint_dir,
                                       pretraining=pretraining)
            model.train()

            if mIoU > best_mIoU:
                best_mIoU = mIoU
                if save_teacher:
                    _save_checkpoint(i_iter,
                                     ema_model,
                                     optimizer,
                                     config,
                                     save_best=True)
                else:
                    _save_checkpoint(i_iter,
                                     model,
                                     optimizer,
                                     config,
                                     save_best=True)
                iters_without_improve = 0
            else:
                iters_without_improve += val_per_iter
            '''
            if the performance has not improve in N iterations, try to reload best model to optimize again with a lower LR
            Simulating an iterative training'''
            if iters_without_improve > num_iterations / 5.:
                print('Re-loading a previous best model')
                checkpoint = torch.load(
                    os.path.join(checkpoint_dir, 'best_model.pth'))
                model.load_state_dict(checkpoint['model'])
                ema_model = create_ema_model(model, Res_Deeplab)
                ema_model.train()
                ema_model = ema_model.cuda()
                model.train()
                model = model.cuda()
                iters_without_improve = 0  # reset timer

    _save_checkpoint(num_iterations, model, optimizer, config)

    # FINISH TRAINING, evaluate again
    model.eval()
    mIoU, eval_loss = evaluate(model,
                               dataset,
                               deeplabv2=deeplabv2,
                               ignore_label=ignore_label,
                               save_dir=checkpoint_dir,
                               pretraining=pretraining)
    model.train()

    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter, model, optimizer, config, save_best=True)

    # TRY IMPROVING BEST MODEL WITH EMA MODEL OR UPDATING BN STATS

    # Load best model
    checkpoint = torch.load(os.path.join(checkpoint_dir, 'best_model.pth'))
    model.load_state_dict(checkpoint['model'])
    model = model.cuda()

    model = update_BN_weak_unlabeled_data(model, normalize,
                                          batch_size_unlabeled,
                                          trainloader_remain)
    model.eval()
    mIoU, eval_loss = evaluate(model,
                               dataset,
                               deeplabv2=deeplabv2,
                               ignore_label=ignore_label,
                               save_dir=checkpoint_dir,
                               pretraining=pretraining)
    model.train()
    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter, model, optimizer, config, save_best=True)

    print('BEST MIOU')
    # Previously printed max(best_mIoU_improved, best_mIoU), which raised
    # NameError because `best_mIoU_improved` was never defined.
    print(best_mIoU)

    end = timeit.default_timer()
    # NOTE(review): `start` is not defined in this function; presumably a
    # module-level timer set before main() runs — confirm.
    print('Total time: ' + str(end - start) + ' seconds')
def main():
    print(config)
    cudnn.enabled = True
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)
    torch.backends.cudnn.deterministic = True

    if pretraining == 'COCO':
        from utils.transformsgpu import normalize_bgr as normalize
    else:
        from utils.transformsgpu import normalize_rgb as normalize

    batch_size_unlabeled = int(batch_size / 2)
    batch_size_labeled = int(batch_size * 1)

    RAMP_UP_ITERS = 2000

    data_loader = get_loader('cityscapes')
    data_path = get_data_path('cityscapes')
    data_aug = Compose([
        RandomCrop_city(input_size)
    ])  # from 1024x2048 to resize 512x1024 to crop input_size (512x512)
    train_dataset = data_loader(data_path,
                                is_transform=True,
                                augmentations=data_aug,
                                img_size=input_size,
                                pretraining=pretraining)

    from data.gta5_loader import gtaLoader
    data_loader_gta = gtaLoader
    data_path_gta = get_data_path('gta5')
    data_aug_gta = Compose([
        RandomCrop_city(input_size)
    ])  # from 1024x2048 to resize 512x1024 to crop input_size (512x512)
    train_dataset_gta = data_loader_gta(data_path_gta,
                                        is_transform=True,
                                        augmentations=data_aug_gta,
                                        img_size=input_size,
                                        pretraining=pretraining)

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    partial_size = labeled_samples
    print('Training on number of samples:', partial_size)

    class_weights_curr = ClassBalancing(
        labeled_iters=int(labeled_samples / batch_size_labeled),
        unlabeled_iters=int(
            (train_dataset_size - labeled_samples) / batch_size_unlabeled),
        n_classes=num_classes)

    feature_memory = FeatureMemory(num_samples=labeled_samples,
                                   dataset=dataset,
                                   memory_per_class=256,
                                   feature_size=256,
                                   n_classes=num_classes)

    # select the partition
    if split_id is not None:
        train_ids = pickle.load(open(split_id, 'rb'))
        print('loading train ids from {}'.format(split_id))
    else:
        train_ids = np.arange(train_dataset_size)
        np.random.shuffle(train_ids)

    train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=batch_size_labeled,
                                  sampler=train_sampler,
                                  num_workers=num_workers,
                                  pin_memory=True)
    trainloader_iter = iter(trainloader)

    # GTA5
    train_ids_gta = np.arange(len(train_dataset_gta))
    np.random.shuffle(train_ids_gta)
    train_sampler_gta = data.sampler.SubsetRandomSampler(train_ids_gta)
    trainloader_gta = data.DataLoader(train_dataset_gta,
                                      batch_size=batch_size_labeled,
                                      sampler=train_sampler_gta,
                                      num_workers=num_workers,
                                      pin_memory=True)
    trainloader_iter_gta = iter(trainloader_gta)

    train_remain_sampler = data.sampler.SubsetRandomSampler(
        train_ids[partial_size:])
    trainloader_remain = data.DataLoader(train_dataset,
                                         batch_size=batch_size_unlabeled,
                                         sampler=train_remain_sampler,
                                         num_workers=num_workers,
                                         pin_memory=True)
    trainloader_remain_iter = iter(trainloader_remain)

    # LOSSES
    unlabeled_loss = CrossEntropyLoss2dPixelWiseWeighted().cuda()
    supervised_loss = CrossEntropy2d(ignore_label=ignore_label).cuda()
    ''' Deeplab model '''
    # Define network
    if deeplabv2:
        if pretraining == 'COCO':  # coco and iamgenet resnet architectures differ a little, just on how to do the stride
            from model.deeplabv2 import Res_Deeplab
        else:  # imagenet pretrained (more modern modification)
            from model.deeplabv2_imagenet import Res_Deeplab

    else:
        from model.deeplabv3 import Res_Deeplab

    # create network
    model = Res_Deeplab(num_classes=num_classes)

    # load pretrained parameters
    if pretraining == 'COCO':
        saved_state_dict = model_zoo.load_url(
            'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/resnet101COCO-41f33a49.pth'
        )  # COCO pretraining
    else:
        saved_state_dict = model_zoo.load_url(
            'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'
        )  # iamgenet pretrainning

    # Copy loaded parameters to model
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])

    model.load_state_dict(new_params)

    # Optimizer for segmentation network
    learning_rate_object = Learning_Rate_Object(
        config['training']['learning_rate'])

    optimizer = torch.optim.SGD(model.optim_parameters(learning_rate_object),
                                lr=learning_rate,
                                momentum=momentum,
                                weight_decay=weight_decay)

    ema_model = create_ema_model(model, Res_Deeplab)
    ema_model.train()
    ema_model = ema_model.cuda()
    model.train()
    model = model.cuda()
    cudnn.benchmark = True

    # checkpoint = torch.load('/home/snowflake/Escritorio/Semi-Sup/saved/Deep_cont/best_model.pth')
    # model.load_state_dict(checkpoint['model'])

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    with open(checkpoint_dir + '/config.json', 'w') as handle:
        json.dump(config, handle, indent=4, sort_keys=False)
    pickle.dump(train_ids,
                open(os.path.join(checkpoint_dir, 'train_split.pkl'), 'wb'))

    interp = nn.Upsample(size=(input_size[0], input_size[1]),
                         mode='bilinear',
                         align_corners=True)

    epochs_since_start = 0
    start_iteration = 0
    best_mIoU = 0  # best metric while training
    iters_without_improve = 0

    # TRAINING
    for i_iter in range(start_iteration, num_iterations):
        model.train()  # set mode to training
        optimizer.zero_grad()
        a = time.time()

        loss_l_value = 0.
        adjust_learning_rate(optimizer, i_iter)
        ''' LABELED SAMPLES '''
        # Get batch
        is_cityscapes = i_iter % 2 == 0
        is_gta = not is_cityscapes
        if num_iterations - i_iter > 100:
            # Last 100 itereations only citysacpes data
            is_cityscapes = True

        if is_cityscapes:
            try:
                batch = next(trainloader_iter)
                if batch[0].shape[0] != batch_size_labeled:
                    batch = next(trainloader_iter)
            except:  # finish epoch, rebuild the iterator
                epochs_since_start = epochs_since_start + 1
                # print('Epochs since start: ',epochs_since_start)
                trainloader_iter = iter(trainloader)
                batch = next(trainloader_iter)
        else:
            try:
                batch = next(trainloader_iter_gta)
                if batch[0].shape[0] != batch_size_labeled:
                    train_ids_gta = np.arange(len(train_dataset_gta))
                    np.random.shuffle(train_ids_gta)
                    train_sampler_gta = data.sampler.SubsetRandomSampler(
                        train_ids_gta)
                    trainloader_gta = data.DataLoader(
                        train_dataset_gta,
                        batch_size=batch_size_labeled,
                        sampler=train_sampler_gta,
                        num_workers=num_workers,
                        pin_memory=True)
                    trainloader_iter_gta = iter(trainloader_gta)
                    batch = next(trainloader_iter_gta)
            except:  # finish epoch, rebuild the iterator
                # print('Epochs since start: ',epochs_since_start)
                trainloader_iter_gta = iter(trainloader_gta)
                batch = next(trainloader_iter_gta)

        images, labels, _, _, _ = batch
        images = images.cuda()
        labels = labels.cuda()
        ''' UNLABELED SAMPLES '''
        try:
            batch_remain = next(trainloader_remain_iter)
            if batch_remain[0].shape[0] != batch_size_unlabeled:
                batch_remain = next(trainloader_remain_iter)
        except:
            trainloader_remain_iter = iter(trainloader_remain)
            batch_remain = next(trainloader_remain_iter)

        # Unlabeled
        unlabeled_images, _, _, _, _ = batch_remain
        unlabeled_images = unlabeled_images.cuda()

        # Create pseudolabels
        with torch.no_grad():
            if use_teacher:
                logits_u_w, features_weak_unlabeled = ema_model(
                    normalize(unlabeled_images, dataset), return_features=True)
            else:
                model.eval()
                logits_u_w, features_weak_unlabeled = model(
                    normalize(unlabeled_images, dataset), return_features=True)
            logits_u_w = interp(logits_u_w).detach()  # prediction unlabeled
            softmax_u_w = torch.softmax(logits_u_w, dim=1)
            max_probs, pseudo_label = torch.max(softmax_u_w,
                                                dim=1)  # Get pseudolabels

        model.train()

        if is_cityscapes:
            class_weights_curr.add_frequencies(labels.cpu().numpy(),
                                               pseudo_label.cpu().numpy())

        images2, labels2, _, _ = augment_samples(images,
                                                 labels,
                                                 None,
                                                 random.random() < 0.25,
                                                 batch_size_labeled,
                                                 ignore_label,
                                                 weak=True)
        '''
        UNLABELED DATA
        '''
        '''
        CROSS ENTROPY FOR UNLABELED USING PSEUDOLABELS
        Once you have the speudolabel, perform strong augmetnation to force the netowrk to yield lower confidence scores for pushing them up
        '''

        do_classmix = i_iter > RAMP_UP_ITERS and random.random(
        ) < 0.75  # only after rampup perfrom classmix
        unlabeled_images_aug1, pseudo_label1, max_probs1, unlabeled_aug1_params = augment_samples(
            unlabeled_images, pseudo_label, max_probs, do_classmix,
            batch_size_unlabeled, ignore_label)

        do_classmix = i_iter > RAMP_UP_ITERS and random.random(
        ) < 0.75  # only after rampup perfrom classmix

        unlabeled_images_aug2, pseudo_label2, max_probs2, unlabeled_aug2_params = augment_samples(
            unlabeled_images, pseudo_label, max_probs, do_classmix,
            batch_size_unlabeled, ignore_label)

        joined_unlabeled = torch.cat(
            (unlabeled_images_aug1, unlabeled_images_aug2), dim=0)
        joined_pseudolabels = torch.cat((pseudo_label1, pseudo_label2), dim=0)
        joined_maxprobs = torch.cat((max_probs1, max_probs2), dim=0)

        pred_joined_unlabeled, features_joined_unlabeled = model(
            normalize(joined_unlabeled, dataset), return_features=True)
        pred_joined_unlabeled = interp(pred_joined_unlabeled)

        joined_labeled = images2
        joined_labels = labels2
        labeled_pred, labeled_features = model(normalize(
            joined_labeled, dataset),
                                               return_features=True)
        labeled_pred = interp(labeled_pred)

        class_weights = torch.from_numpy(np.ones((num_classes))).cuda()
        if i_iter > RAMP_UP_ITERS:
            class_weights = torch.from_numpy(
                class_weights_curr.get_weights(num_iterations,
                                               only_labeled=False)).cuda()

        loss = 0
        # SUPERVISED SEGMENTATION
        labeled_loss = supervised_loss(labeled_pred,
                                       joined_labels,
                                       weight=class_weights.float())  #
        loss = loss + labeled_loss

        # SELF-SUPERVISED SEGMENTATION
        '''
        Cross entropy loss using pseudolabels. 
        '''

        unlabeled_loss = CrossEntropyLoss2dPixelWiseWeighted(
            ignore_index=ignore_label, weight=class_weights.float()).cuda()  #

        # Pseudo-label weighting
        pixelWiseWeight = sigmoid_ramp_up(i_iter, RAMP_UP_ITERS) * torch.ones(
            joined_maxprobs.shape).cuda()
        pixelWiseWeight = pixelWiseWeight * torch.pow(joined_maxprobs.detach(),
                                                      6)

        # Pseudo-label loss
        loss_ce_unlabeled = unlabeled_loss(pred_joined_unlabeled,
                                           joined_pseudolabels,
                                           pixelWiseWeight)

        loss = loss + loss_ce_unlabeled

        # entropy loss
        valid_mask = (joined_pseudolabels != ignore_label).unsqueeze(1)
        loss = loss + entropy_loss(
            torch.nn.functional.softmax(pred_joined_unlabeled, dim=1),
            valid_mask) * 0.01

        # CONTRASTIVE LEARNING
        if is_cityscapes:
            if i_iter > RAMP_UP_ITERS - 1000:
                # Build Memory Bank 1000 iters before starting to do contrsative
                with torch.no_grad():
                    if use_teacher:
                        labeled_pred_ema, labeled_features_ema = ema_model(
                            normalize(joined_labeled, dataset),
                            return_features=True)
                    else:
                        model.eval()
                        labeled_pred_ema, labeled_features_ema = model(
                            normalize(joined_labeled, dataset),
                            return_features=True)
                        model.train()

                    labeled_pred_ema = interp(labeled_pred_ema)
                    probability_prediction_ema, label_prediction_ema = torch.max(
                        torch.softmax(labeled_pred_ema,
                                      dim=1), dim=1)  # Get pseudolabels

                labels_down = nn.functional.interpolate(
                    joined_labels.float().unsqueeze(1),
                    size=(labeled_features_ema.shape[2],
                          labeled_features_ema.shape[3]),
                    mode='nearest').squeeze(1)
                label_prediction_down = nn.functional.interpolate(
                    label_prediction_ema.float().unsqueeze(1),
                    size=(labeled_features_ema.shape[2],
                          labeled_features_ema.shape[3]),
                    mode='nearest').squeeze(1)
                probability_prediction_down = nn.functional.interpolate(
                    probability_prediction_ema.float().unsqueeze(1),
                    size=(labeled_features_ema.shape[2],
                          labeled_features_ema.shape[3]),
                    mode='nearest').squeeze(1)

                # get mask where the labeled predictions are correct
                mask_prediction_correctly = (
                    (label_prediction_down == labels_down).float() *
                    (probability_prediction_down > 0.95).float()).bool()

                labeled_features_correct = labeled_features_ema.permute(
                    0, 2, 3, 1)
                labels_down_correct = labels_down[mask_prediction_correctly]
                labeled_features_correct = labeled_features_correct[
                    mask_prediction_correctly, ...]

                # get projected features
                with torch.no_grad():
                    if use_teacher:
                        proj_labeled_features_correct = ema_model.projection_head(
                            labeled_features_correct)
                    else:
                        model.eval()
                        proj_labeled_features_correct = model.projection_head(
                            labeled_features_correct)
                        model.train()

                # updated memory bank
                feature_memory.add_features_from_sample_learned(
                    ema_model, proj_labeled_features_correct,
                    labels_down_correct, batch_size_labeled)

        if i_iter > RAMP_UP_ITERS:
            '''
            LABELED TO LABELED. Force features from laeled samples, to be similar to other features from the same class (which also leads to good predictions)

            '''

            # now we can take all. as they are not the prototypes, here we are gonan force these features to be similar as the correct ones
            mask_prediction_correctly = (labels_down != ignore_label)

            labeled_features_all = labeled_features.permute(0, 2, 3, 1)
            labels_down_all = labels_down[mask_prediction_correctly]
            labeled_features_all = labeled_features_all[
                mask_prediction_correctly, ...]

            # get prediction features
            proj_labeled_features_all = model.projection_head(
                labeled_features_all)
            pred_labeled_features_all = model.prediction_head(
                proj_labeled_features_all)

            loss_contr_labeled = contrastive_class_to_class_learned_memory(
                model, pred_labeled_features_all, labels_down_all, num_classes,
                feature_memory.memory)

            loss = loss + loss_contr_labeled * 0.2
            '''
            CONTRASTIVE LEARNING ON UNLABELED DATA. align unlabeled features to labeled features
            '''

            joined_pseudolabels_down = nn.functional.interpolate(
                joined_pseudolabels.float().unsqueeze(1),
                size=(features_joined_unlabeled.shape[2],
                      features_joined_unlabeled.shape[3]),
                mode='nearest').squeeze(1)

            # take out the features from black pixels from zooms out and augmetnations (ignore labels on pseduoalebl)
            mask = (joined_pseudolabels_down != ignore_label)

            features_joined_unlabeled = features_joined_unlabeled.permute(
                0, 2, 3, 1)
            features_joined_unlabeled = features_joined_unlabeled[mask, ...]
            joined_pseudolabels_down = joined_pseudolabels_down[mask]

            # get projected features
            proj_feat_unlabeled = model.projection_head(
                features_joined_unlabeled)
            pred_feat_unlabeled = model.prediction_head(proj_feat_unlabeled)

            loss_contr_unlabeled = contrastive_class_to_class_learned_memory(
                model, pred_feat_unlabeled, joined_pseudolabels_down,
                num_classes, feature_memory.memory)

            loss = loss + loss_contr_unlabeled * 0.2

        loss_l_value += loss.item()

        # optimize
        loss.backward()
        optimizer.step()

        m = 1 - (1 - 0.995) * (math.cos(math.pi * i_iter / num_iterations) +
                               1) / 2
        ema_model = update_ema_variables(ema_model=ema_model,
                                         model=model,
                                         alpha_teacher=m,
                                         iteration=i_iter)

        # print('iter = {0:6d}/{1:6d}, loss_l = {2:.3f}'.format(i_iter, num_iterations, loss_l_value))

        if i_iter % save_checkpoint_every == 0 and i_iter != 0:
            _save_checkpoint(i_iter, model, optimizer, config)

        if i_iter % val_per_iter == 0 and i_iter != 0:
            print('iter = {0:6d}/{1:6d}'.format(i_iter, num_iterations))

            model.eval()
            mIoU, eval_loss = evaluate(model,
                                       dataset,
                                       ignore_label=ignore_label,
                                       save_dir=checkpoint_dir,
                                       pretraining=pretraining)
            model.train()

            if mIoU > best_mIoU:
                best_mIoU = mIoU
                if save_teacher:
                    _save_checkpoint(i_iter,
                                     ema_model,
                                     optimizer,
                                     config,
                                     save_best=True)
                else:
                    _save_checkpoint(i_iter,
                                     model,
                                     optimizer,
                                     config,
                                     save_best=True)
                iters_without_improve = 0
            else:
                iters_without_improve += val_per_iter
            '''
            if the performance has not improve in N iterations, try to reload best model to optimize again with a lower LR
            Simulating an iterative training'''
            if iters_without_improve > num_iterations / 5.:
                print('Re-loading a previous best model')
                checkpoint = torch.load(
                    os.path.join(checkpoint_dir, f'best_model.pth'))
                model.load_state_dict(checkpoint['model'])
                ema_model = create_ema_model(model, Res_Deeplab)
                ema_model.train()
                ema_model = ema_model.cuda()
                model.train()
                model = model.cuda()
                iters_without_improve = 0  # reset timer

    _save_checkpoint(num_iterations, model, optimizer, config)

    # FINISH TRAINING, evaluate again
    model.eval()
    mIoU, eval_loss = evaluate(model,
                               dataset,
                               deeplabv2=deeplabv2,
                               ignore_label=ignore_label,
                               save_dir=checkpoint_dir,
                               pretraining=pretraining)
    model.train()

    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter, model, optimizer, config, save_best=True)

    # TRY IMPROVING BEST MODEL WITH EMA MODEL OR UPDATING BN STATS

    # Load best model
    checkpoint = torch.load(os.path.join(checkpoint_dir, f'best_model.pth'))
    model.load_state_dict(checkpoint['model'])
    model = model.cuda()

    model = update_BN_weak_unlabeled_data(model, normalize,
                                          batch_size_unlabeled,
                                          trainloader_remain)
    model.eval()
    mIoU, eval_loss = evaluate(model,
                               dataset,
                               deeplabv2=deeplabv2,
                               ignore_label=ignore_label,
                               save_dir=checkpoint_dir,
                               pretraining=pretraining)
    model.train()
    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter, model, optimizer, config, save_best=True)

    print('BEST MIOU')
    print(max(best_mIoU_improved, best_mIoU))

    end = timeit.default_timer()
    print('Total time: ' + str(end - start) + ' seconds')