Example #1
0
def main():
    """Create the model and start the evaluation process.

    Builds ``Res_Deeplab``, restores the weights stored under
    ``checkpoint['model']`` in ``args.model_path``, moves the network to the
    GPU in eval mode and calls ``evaluate``. All settings come from
    module-level globals (``args``, ``save_dir``, ``num_classes``,
    ``dataset``, ``ignore_label``, ``input_size``).
    """
    # exist_ok makes creation idempotent without a separate existence check.
    os.makedirs(save_dir, exist_ok=True)

    model = Res_Deeplab(num_classes=num_classes)

    checkpoint = torch.load(args.model_path)
    try:
        model.load_state_dict(checkpoint['model'])
    except RuntimeError:
        # load_state_dict raises RuntimeError when the checkpoint keys do not
        # match the model's. Presumably the checkpoint was saved from a
        # DataParallel-wrapped model (keys prefixed with "module."), so retry
        # with the model wrapped the same way. Only RuntimeError is caught so
        # that e.g. KeyboardInterrupt or a missing 'model' key stay visible.
        model = torch.nn.DataParallel(model, device_ids=args.gpu)
        model.load_state_dict(checkpoint['model'])

    model.cuda()
    model.eval()

    evaluate(model,
             dataset,
             ignore_label=ignore_label,
             save_output_images=args.save_output_images,
             save_dir=save_dir,
             input_size=input_size)
Example #2
0
def main():
    """Build the configured DeepLab variant, load its weights and evaluate it.

    The architecture is picked from module-level settings: a
    ``config['version']`` containing ``"2"`` selects DeepLabv2 (the COCO or
    the ImageNet implementation depending on ``pretraining``); any other
    version selects DeepLabv3. Weights are read from ``checkpoint['model']``
    in ``args.model_path``.
    """
    deeplabv2 = "2" in config['version']

    # Resolve the network class; every branch binds the same local name.
    if not deeplabv2:
        from model.deeplabv3 import Res_Deeplab
    elif pretraining == 'COCO':
        # COCO- and ImageNet-pretrained ResNets differ slightly in how the
        # stride is applied, hence two separate DeepLabv2 implementations.
        from model.deeplabv2 import Res_Deeplab
    else:
        # ImageNet-pretrained (more modern modification).
        from model.deeplabv2_imagenet import Res_Deeplab

    net = Res_Deeplab(num_classes=num_classes)
    state = torch.load(args.model_path)
    net.load_state_dict(state['model'])

    net = net.cuda()
    net.eval()

    evaluate(net,
             dataset,
             deeplabv2=deeplabv2,
             ignore_label=ignore_label,
             pretraining=pretraining)
Example #3
0
def _generate_batch_mix_mask(mix_mask, argmax_u_w, inputs_u_w, batch_size,
                             ignore_label):
    """Return the stacked per-sample mixing masks for one unlabeled batch.

    Args:
        mix_mask: Strategy name. ``"class"`` builds each mask from a random
            half (floor) of the classes predicted in that image, excluding
            ``ignore_label``. ``"cut"`` uses
            ``transformmasks.generate_cutout_mask``. ``"cow"`` uses
            ``transformmasks.generate_cow_mask`` with a log-uniform sigma in
            [8, 32] and p = 0.5. ``None`` yields an all-ones tensor shaped
            like ``inputs_u_w``.
        argmax_u_w: Per-image predicted class maps, indexed per sample; only
            read for ``"class"``.
        inputs_u_w: Weakly transformed input batch; ``shape[2:4]`` sizes the
            ``"cut"``/``"cow"`` masks.
        batch_size: Number of per-sample masks to generate.
        ignore_label: Label value never selected by the ``"class"`` strategy.

    Returns:
        A CUDA tensor with one mask per sample concatenated along dim 0, or
        the all-ones tensor for ``None``.

    Raises:
        ValueError: If ``mix_mask`` is not a supported strategy.
    """
    if mix_mask is None:
        return torch.ones((inputs_u_w.shape)).cuda()

    img_size = inputs_u_w.shape[2:4]
    masks = []
    if mix_mask == "class":
        for image_i in range(batch_size):
            classes = torch.unique(argmax_u_w[image_i])
            classes = classes[classes != ignore_label]
            nclasses = classes.shape[0]
            # Keep floor(nclasses / 2) randomly chosen classes.
            chosen = np.random.choice(nclasses, nclasses // 2, replace=False)
            classes = classes[torch.Tensor(chosen).long()].cuda()
            masks.append(
                transformmasks.generate_class_mask(
                    argmax_u_w[image_i], classes).unsqueeze(0).cuda())
    elif mix_mask == 'cut':
        for _ in range(batch_size):
            masks.append(
                torch.from_numpy(
                    transformmasks.generate_cutout_mask(img_size)).unsqueeze(
                        0).cuda().float())
    elif mix_mask == "cow":
        sigma_min, sigma_max = 8, 32
        p_min, p_max = 0.5, 0.5
        for _ in range(batch_size):
            # Log-uniform sigma; p's range is degenerate (always 0.5) but it
            # is still drawn, keeping the numpy RNG sequence unchanged.
            sigma = np.exp(
                np.random.uniform(np.log(sigma_min), np.log(sigma_max)))
            p = np.random.uniform(p_min, p_max)
            masks.append(
                torch.from_numpy(
                    transformmasks.generate_cow_mask(
                        img_size, sigma, p,
                        seed=None)).unsqueeze(0).cuda().float())
    else:
        raise ValueError('Unsupported mix_mask: {!r}'.format(mix_mask))

    # One concatenation instead of re-growing the tensor for every sample.
    return torch.cat(masks)


def _fraction_above_threshold(max_probs, threshold):
    """Return the fraction of elements of ``max_probs`` that are >= ``threshold``.

    The count is computed on the tensor's device and divided by ``numel()``,
    so the whole tensor is never copied to the host just to read its size.
    """
    return max_probs.ge(threshold).sum().item() / max_probs.numel()


def main():
    """Train the segmentation network, optionally with an unlabeled consistency loss.

    Every setting is a module-level global (``config``, ``gpus``,
    ``dataset``, ``batch_size``, ``num_iterations``, ``train_unlabeled``,
    ``mix_mask``, ``pixel_weight``, ``consistency_loss``, ``checkpoint_dir``,
    ...). Each iteration computes a supervised loss on a labeled batch. When
    ``train_unlabeled`` is set, pseudo-labels predicted by ``ema_model`` on a
    weakly transformed unlabeled batch supervise the student on a strongly
    transformed (mixed) version of the same batch, and the mean-teacher
    model is updated after every optimizer step. Checkpoints, TensorBoard
    scalars and validation runs happen every ``save_checkpoint_every`` /
    ``log_per_iter`` / ``val_per_iter`` iterations, plus once at the end.

    Raises:
        ValueError: For an unsupported ``dataset`` or ``optimizer_type``, or,
            when training on unlabeled data, an unsupported
            ``consistency_loss`` or ``mix_mask``.
    """
    torch.cuda.empty_cache()
    print(config)

    best_mIoU = 0
    # Pseudo-label confidence used by the threshold-based weightings below.
    confidence_threshold = 0.968

    # Consistency loss for unlabeled data. The single- and multi-GPU variants
    # both honour ignore_label so they compute the same loss.
    unlabeled_loss = None
    if consistency_loss == 'CE':
        unlabeled_loss = CrossEntropyLoss2dPixelWiseWeighted(
            ignore_index=ignore_label)
    elif consistency_loss == 'MSE':
        unlabeled_loss = MSELoss2d()
    if unlabeled_loss is not None:
        if len(gpus) > 1:
            unlabeled_loss = torch.nn.DataParallel(unlabeled_loss,
                                                   device_ids=gpus)
        unlabeled_loss = unlabeled_loss.cuda()

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=num_classes)

    # load pretrained parameters
    if restore_from.startswith('http'):
        saved_state_dict = model_zoo.load_url(restore_from)
    else:
        saved_state_dict = torch.load(restore_from)

    # Copy only parameters present in the checkpoint with a matching shape;
    # every other parameter keeps its freshly initialised value.
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # Initiate ema-model (teacher) before the student gets wrapped.
    if train_unlabeled:
        ema_model = create_ema_model(model)
        ema_model.train()
        ema_model = ema_model.cuda()
    else:
        ema_model = None

    if len(gpus) > 1:
        if use_sync_batchnorm:
            model = convert_model(model)
            model = DataParallelWithCallback(model, device_ids=gpus)
        else:
            model = torch.nn.DataParallel(model, device_ids=gpus)
    model.train()
    model.cuda()

    cudnn.benchmark = True

    if dataset == 'pascal_voc':
        data_loader = get_loader(dataset)
        data_path = get_data_path(dataset)
        train_dataset = data_loader(data_path,
                                    crop_size=input_size,
                                    scale=random_scale,
                                    mirror=random_flip)
    elif dataset == 'cityscapes':
        data_loader = get_loader('cityscapes')
        data_path = get_data_path('cityscapes')
        if random_crop:
            data_aug = Compose([RandomCrop_city(input_size)])
        else:
            data_aug = None

        train_dataset = data_loader(data_path,
                                    is_transform=True,
                                    augmentations=data_aug,
                                    img_size=input_size)
    else:
        raise ValueError('Unsupported dataset: {!r}'.format(dataset))

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    partial_size = labeled_samples
    print('Training on number of samples:', partial_size)
    if split_id is not None:
        # NOTE(review): pickle.load can execute arbitrary code; split_id must
        # point to a trusted file.
        with open(split_id, 'rb') as split_file:
            train_ids = pickle.load(split_file)
        print('loading train ids from {}'.format(split_id))
    else:
        np.random.seed(random_seed)
        train_ids = np.arange(train_dataset_size)
        np.random.shuffle(train_ids)

    # The first partial_size ids are the labeled subset; the rest are used
    # as unlabeled data.
    train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  sampler=train_sampler,
                                  num_workers=num_workers,
                                  pin_memory=True)
    trainloader_iter = iter(trainloader)

    if train_unlabeled:
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=1,
                                             pin_memory=True)
        trainloader_remain_iter = iter(trainloader_remain)

    # Optimizer for segmentation network
    learning_rate_object = Learning_Rate_Object(
        config['training']['learning_rate'])

    if optimizer_type != 'SGD':
        raise ValueError(
            'Unsupported optimizer_type: {!r}'.format(optimizer_type))
    # Under (Sync)DataParallel the network is reachable as ``model.module``.
    base_model = model.module if len(gpus) > 1 else model
    optimizer = optim.SGD(base_model.optim_parameters(learning_rate_object),
                          lr=learning_rate,
                          momentum=momentum,
                          weight_decay=weight_decay)

    optimizer.zero_grad()

    interp = nn.Upsample(size=(input_size[0], input_size[1]),
                         mode='bilinear',
                         align_corners=True)

    start_iteration = 0

    if args.resume:
        start_iteration, model, optimizer, ema_model = _resume_checkpoint(
            args.resume, model, optimizer, ema_model)

    accumulated_loss_l = []
    if train_unlabeled:
        accumulated_loss_u = []

    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(os.path.join(checkpoint_dir, 'config.json'), 'w') as handle:
        json.dump(config, handle, indent=4, sort_keys=False)
    # The context manager guarantees the split is flushed and the file closed.
    with open(os.path.join(checkpoint_dir, 'train_split.pkl'),
              'wb') as split_file:
        pickle.dump(train_ids, split_file)

    tensorboard_writer = None
    if use_tensorboard:
        tensorboard_writer = tensorboard.SummaryWriter(log_dir, flush_secs=30)

    epochs_since_start = 0
    for i_iter in range(start_iteration, num_iterations):
        model.train()

        loss_l_value = 0
        if train_unlabeled:
            loss_u_value = 0

        optimizer.zero_grad()

        if lr_schedule:
            adjust_learning_rate(optimizer, i_iter)

        # Labeled batch: a short batch is replaced by the next one, and an
        # exhausted iterator starts a new epoch. Only StopIteration is caught
        # so Ctrl-C and loader errors are not mistaken for an epoch end.
        try:
            batch = next(trainloader_iter)
            if batch[0].shape[0] != batch_size:
                batch = next(trainloader_iter)
        except StopIteration:
            epochs_since_start = epochs_since_start + 1
            print('Epochs since start: ', epochs_since_start)
            trainloader_iter = iter(trainloader)
            batch = next(trainloader_iter)

        weak_parameters = {"flip": 0}

        images, labels, _, _, _ = batch
        images = images.cuda()
        labels = labels.cuda()

        images, labels = weakTransform(weak_parameters,
                                       data=images,
                                       target=labels)
        intermediary_var = model(images)
        pred = interp(intermediary_var)

        L_l = loss_calc(pred, labels)

        if train_unlabeled:
            try:
                batch_remain = next(trainloader_remain_iter)
                if batch_remain[0].shape[0] != batch_size:
                    batch_remain = next(trainloader_remain_iter)
            except StopIteration:
                trainloader_remain_iter = iter(trainloader_remain)
                batch_remain = next(trainloader_remain_iter)

            images_remain, _, _, _, _ = batch_remain
            images_remain = images_remain.cuda()

            # Teacher predictions on the weakly transformed batch, mapped
            # back through the inverse weak transform.
            inputs_u_w, _ = weakTransform(weak_parameters, data=images_remain)
            logits_u_w = interp(ema_model(inputs_u_w))
            logits_u_w, _ = weakTransform(
                getWeakInverseTransformParameters(weak_parameters),
                data=logits_u_w.detach())

            softmax_u_w = torch.softmax(logits_u_w.detach(), dim=1)
            max_probs, argmax_u_w = torch.max(softmax_u_w, dim=1)

            MixMask = _generate_batch_mix_mask(mix_mask, argmax_u_w,
                                               inputs_u_w, batch_size,
                                               ignore_label)

            # Dict literals evaluate left to right, so the random draws happen
            # in the order flip, ColorJitter, GaussianBlur.
            strong_parameters = {
                "Mix": MixMask,
                "flip": random.randint(0, 1) if random_flip else 0,
                "ColorJitter": random.uniform(0, 1) if color_jitter else 0,
                "GaussianBlur": random.uniform(0, 1) if gaussian_blur else 0,
            }

            inputs_u_s, _ = strongTransform(strong_parameters,
                                            data=images_remain)
            logits_u_s = interp(model(inputs_u_s))

            softmax_u_w_mixed, _ = strongTransform(strong_parameters,
                                                   data=softmax_u_w)
            max_probs, pseudo_label = torch.max(softmax_u_w_mixed, dim=1)

            if pixel_weight == "threshold_uniform":
                unlabeled_weight = _fraction_above_threshold(
                    max_probs, confidence_threshold)
                pixelWiseWeight = unlabeled_weight * torch.ones(
                    max_probs.shape).cuda()
            elif pixel_weight == "threshold":
                pixelWiseWeight = max_probs.ge(
                    confidence_threshold).long().cuda()
            elif pixel_weight == 'sigmoid':
                max_iter = 10000
                pixelWiseWeight = sigmoid_ramp_up(
                    i_iter, max_iter) * torch.ones(max_probs.shape).cuda()
            elif pixel_weight == False:  # noqa: E712 - config may also give 0
                pixelWiseWeight = torch.ones(max_probs.shape).cuda()

            if consistency_loss == 'CE':
                L_u = consistency_weight * unlabeled_loss(
                    logits_u_s, pseudo_label, pixelWiseWeight)
            elif consistency_loss == 'MSE':
                unlabeled_weight = _fraction_above_threshold(
                    max_probs, confidence_threshold)
                L_u = consistency_weight * unlabeled_weight * unlabeled_loss(
                    logits_u_s, softmax_u_w_mixed)
            else:
                raise ValueError('Unsupported consistency_loss: {!r}'.format(
                    consistency_loss))

            loss = L_l + L_u

        else:
            loss = L_l

        # DataParallel losses come back per device and are averaged here.
        if len(gpus) > 1:
            loss = loss.mean()
            loss_l_value += L_l.mean().item()
            if train_unlabeled:
                loss_u_value += L_u.mean().item()
        else:
            loss_l_value += L_l.item()
            if train_unlabeled:
                loss_u_value += L_u.item()

        loss.backward()
        optimizer.step()

        # update Mean teacher network
        if ema_model is not None:
            alpha_teacher = 0.99
            ema_model = update_ema_variables(ema_model=ema_model,
                                             model=model,
                                             alpha_teacher=alpha_teacher,
                                             iteration=i_iter)

        if train_unlabeled:
            print('iter = {0:6d}/{1:6d}, loss_l = {2:.3f}, loss_u = {3:.3f}'.
                  format(i_iter, num_iterations, loss_l_value, loss_u_value))
        else:
            print('iter = {0:6d}/{1:6d}, loss_l = {2:.3f}'.format(
                i_iter, num_iterations, loss_l_value))

        if i_iter % save_checkpoint_every == 0 and i_iter != 0:
            _save_checkpoint(i_iter, model, optimizer, config, ema_model)

        if use_tensorboard:
            accumulated_loss_l.append(loss_l_value)
            if train_unlabeled:
                accumulated_loss_u.append(loss_u_value)
            if i_iter % log_per_iter == 0 and i_iter != 0:
                tensorboard_writer.add_scalar('Training/Supervised loss',
                                              np.mean(accumulated_loss_l),
                                              i_iter)
                accumulated_loss_l = []

                if train_unlabeled:
                    tensorboard_writer.add_scalar('Training/Unsupervised loss',
                                                  np.mean(accumulated_loss_u),
                                                  i_iter)
                    accumulated_loss_u = []

        if i_iter % val_per_iter == 0 and i_iter != 0:
            model.eval()
            mIoU, eval_loss = evaluate(model,
                                       dataset,
                                       ignore_label=ignore_label,
                                       input_size=(512, 1024),
                                       save_dir=checkpoint_dir)

            model.train()

            if mIoU > best_mIoU and save_best_model:
                best_mIoU = mIoU
                _save_checkpoint(i_iter,
                                 model,
                                 optimizer,
                                 config,
                                 ema_model,
                                 save_best=True)

            if use_tensorboard:
                tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
                tensorboard_writer.add_scalar('Validation/Loss', eval_loss,
                                              i_iter)

        if save_unlabeled_images and train_unlabeled and i_iter % save_checkpoint_every == 0:
            # Saves two mixed images and the corresponding prediction
            save_image(inputs_u_s[0].cpu(), i_iter, 'input1',
                       palette.CityScpates_palette)
            save_image(inputs_u_s[1].cpu(), i_iter, 'input2',
                       palette.CityScpates_palette)
            _, pred_u_s = torch.max(logits_u_s, dim=1)
            save_image(pred_u_s[0].cpu(), i_iter, 'pred1',
                       palette.CityScpates_palette)
            save_image(pred_u_s[1].cpu(), i_iter, 'pred2',
                       palette.CityScpates_palette)

    _save_checkpoint(num_iterations, model, optimizer, config, ema_model)

    # Equal to the last loop index when the loop ran; unlike the loop
    # variable, it is also defined when the loop body never executed
    # (e.g. resuming an already finished run).
    final_iter = num_iterations - 1

    model.eval()
    mIoU, val_loss = evaluate(model,
                              dataset,
                              ignore_label=ignore_label,
                              input_size=(512, 1024),
                              save_dir=checkpoint_dir)

    model.train()
    if mIoU > best_mIoU and save_best_model:
        _save_checkpoint(final_iter,
                         model,
                         optimizer,
                         config,
                         ema_model,
                         save_best=True)

    if tensorboard_writer is not None:
        tensorboard_writer.add_scalar('Validation/mIoU', mIoU, final_iter)
        tensorboard_writer.add_scalar('Validation/Loss', val_loss, final_iter)
        # Flush pending events; flush_secs alone may leave the tail unwritten.
        tensorboard_writer.close()

    end = timeit.default_timer()
    print('Total time: ' + str(end - start) + ' seconds')
def main():
    print(config)
    cudnn.enabled = True
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)
    torch.backends.cudnn.deterministic = True

    if pretraining == 'COCO':  # depending the pretraining, normalize with bgr or rgb
        from utils.transformsgpu import normalize_bgr as normalize
    else:
        from utils.transformsgpu import normalize_rgb as normalize

    batch_size_unlabeled = int(
        batch_size /
        2)  # because of augmentation anchoring, 2 augmentations per sample
    batch_size_labeled = int(batch_size * 1)
    assert batch_size_unlabeled >= 2, "batch size should be higher than 2"
    assert batch_size_labeled >= 2, "batch size should be higher than 2"
    RAMP_UP_ITERS = 2000  # iterations until contrastive and self-training are taken into account

    # DATASETS / LOADERS
    if dataset == 'pascal_voc':
        data_loader = get_loader(dataset)
        data_path = get_data_path(dataset)
        train_dataset = data_loader(data_path,
                                    crop_size=input_size,
                                    scale=False,
                                    mirror=False,
                                    pretraining=pretraining)

    elif dataset == 'cityscapes':
        data_loader = get_loader('cityscapes')
        data_path = get_data_path('cityscapes')
        if deeplabv2:
            data_aug = Compose([RandomCrop_city(input_size)])
        else:  # for deeplabv3 original resolution
            data_aug = Compose([RandomCrop_city_highres(input_size)])
        train_dataset = data_loader(data_path,
                                    is_transform=True,
                                    augmentations=data_aug,
                                    img_size=input_size,
                                    pretraining=pretraining)

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    partial_size = labeled_samples
    print('Training on number of samples:', partial_size)

    # class weighting  taken unlabeled data into acount in an incremental fashion.
    class_weights_curr = ClassBalancing(
        labeled_iters=int(labeled_samples / batch_size_labeled),
        unlabeled_iters=int(
            (train_dataset_size - labeled_samples) / batch_size_unlabeled),
        n_classes=num_classes)
    # Memory Bank
    feature_memory = FeatureMemory(num_samples=labeled_samples,
                                   dataset=dataset,
                                   memory_per_class=256,
                                   feature_size=256,
                                   n_classes=num_classes)

    # select the partition
    if split_id is not None:
        train_ids = pickle.load(open(split_id, 'rb'))
        print('loading train ids from {}'.format(split_id))
    else:
        train_ids = np.arange(train_dataset_size)
        np.random.shuffle(train_ids)

    # Samplers for labeled data
    train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=batch_size_labeled,
                                  sampler=train_sampler,
                                  num_workers=num_workers,
                                  pin_memory=True)
    trainloader_iter = iter(trainloader)

    # Samplers for unlabeled data
    train_remain_sampler = data.sampler.SubsetRandomSampler(
        train_ids[partial_size:])
    trainloader_remain = data.DataLoader(train_dataset,
                                         batch_size=batch_size_unlabeled,
                                         sampler=train_remain_sampler,
                                         num_workers=num_workers,
                                         pin_memory=True)
    trainloader_remain_iter = iter(trainloader_remain)

    # supervised loss
    supervised_loss = CrossEntropy2d(ignore_label=ignore_label).cuda()
    ''' Deeplab model '''
    # Define network
    if deeplabv2:
        if pretraining == 'COCO':  # coco and imagenet resnet architectures differ a little, just on how to do the stride
            from model.deeplabv2 import Res_Deeplab
        else:  # imagenet pretrained (more modern modification)
            from model.deeplabv2_imagenet import Res_Deeplab

        # load pretrained parameters
        if pretraining == 'COCO':
            saved_state_dict = model_zoo.load_url(
                'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/resnet101COCO-41f33a49.pth'
            )  # COCO pretraining
        else:
            saved_state_dict = model_zoo.load_url(
                'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'
            )  # iamgenet pretrainning

    else:
        from model.deeplabv3 import Res_Deeplab50 as Res_Deeplab
        saved_state_dict = model_zoo.load_url(
            'https://download.pytorch.org/models/resnet50-19c8e357.pth'
        )  # iamgenet pretrainning

    # create network
    model = Res_Deeplab(num_classes=num_classes)

    # Copy loaded parameters to model
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # Optimizer for segmentation network
    learning_rate_object = Learning_Rate_Object(
        config['training']['learning_rate'])

    optimizer = torch.optim.SGD(model.optim_parameters(learning_rate_object),
                                lr=learning_rate,
                                momentum=momentum,
                                weight_decay=weight_decay)

    ema_model = create_ema_model(model, Res_Deeplab)
    ema_model.train()
    ema_model = ema_model.cuda()
    model.train()
    model = model.cuda()
    cudnn.benchmark = True

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    with open(checkpoint_dir + '/config.json', 'w') as handle:
        json.dump(config, handle, indent=4, sort_keys=False)
    pickle.dump(train_ids,
                open(os.path.join(checkpoint_dir, 'train_split.pkl'), 'wb'))

    interp = nn.Upsample(size=(input_size[0], input_size[1]),
                         mode='bilinear',
                         align_corners=True)

    epochs_since_start = 0
    start_iteration = 0
    best_mIoU = 0  # best metric while training
    iters_without_improve = 0

    # TRAINING
    for i_iter in range(start_iteration, num_iterations):
        model.train()  # set mode to training
        optimizer.zero_grad()

        loss_l_value = 0.
        adjust_learning_rate(optimizer, i_iter)

        labeled_turn = i_iter % 2 == 0

        if labeled_turn:  # labeled data optimization
            ''' LABELED SAMPLES '''
            # Get batch
            try:
                batch = next(trainloader_iter)
                if batch[0].shape[0] != batch_size_labeled:
                    batch = next(trainloader_iter)
            except:  # finish epoch, rebuild the iterator
                epochs_since_start = epochs_since_start + 1
                # print('Epochs since start: ',epochs_since_start)
                trainloader_iter = iter(trainloader)
                batch = next(trainloader_iter)

            images, labels, _, _, _ = batch
            images = images.cuda()
            labels = labels.cuda()

            model.train()

            if dataset == 'cityscapes':
                class_weights_curr.add_frequencies_labeled(
                    labels.cpu().numpy())

            images_aug, labels_aug, _, _ = augment_samples(
                images,
                labels,
                None,
                random.random() < 0.2,
                batch_size_labeled,
                ignore_label,
                weak=True)

            # labeled data
            labeled_pred, labeled_features = model(normalize(
                images_aug, dataset),
                                                   return_features=True)
            labeled_pred = interp(labeled_pred)

            # apply class balance for cityspcaes dataset
            class_weights = torch.from_numpy(np.ones((num_classes))).cuda()
            if i_iter > RAMP_UP_ITERS and dataset == 'cityscapes':
                class_weights = torch.from_numpy(
                    class_weights_curr.get_weights(num_iterations,
                                                   only_labeled=False)).cuda()

            loss = 0

            # SUPERVISED SEGMENTATION
            labeled_loss = supervised_loss(labeled_pred,
                                           labels_aug,
                                           weight=class_weights.float())
            loss = loss + labeled_loss

            # CONTRASTIVE LEARNING
            if i_iter > RAMP_UP_ITERS - 1000:
                # Build Memory Bank 1000 iters before starting to do contrsative

                with torch.no_grad():
                    # Get feature vectors from labeled images with EMA model
                    if use_teacher:
                        labeled_pred_ema, labeled_features_ema = ema_model(
                            normalize(images_aug, dataset),
                            return_features=True)
                    else:
                        model.eval()
                        labeled_pred_ema, labeled_features_ema = model(
                            normalize(images_aug, dataset),
                            return_features=True)
                        model.train()
                    labeled_pred_ema = interp(labeled_pred_ema)
                    probability_prediction_ema, label_prediction_ema = torch.max(
                        torch.softmax(labeled_pred_ema,
                                      dim=1), dim=1)  # Get pseudolabels

                # Resize labels, predictions and probabilities,  to feature map resolution
                labels_down = nn.functional.interpolate(
                    labels_aug.float().unsqueeze(1),
                    size=(labeled_features_ema.shape[2],
                          labeled_features_ema.shape[3]),
                    mode='nearest').squeeze(1)
                label_prediction_down = nn.functional.interpolate(
                    label_prediction_ema.float().unsqueeze(1),
                    size=(labeled_features_ema.shape[2],
                          labeled_features_ema.shape[3]),
                    mode='nearest').squeeze(1)
                probability_prediction_down = nn.functional.interpolate(
                    probability_prediction_ema.float().unsqueeze(1),
                    size=(labeled_features_ema.shape[2],
                          labeled_features_ema.shape[3]),
                    mode='nearest').squeeze(1)

                # get mask where the labeled predictions are correct and have a confidence higher than 0.95
                mask_prediction_correctly = (
                    (label_prediction_down == labels_down).float() *
                    (probability_prediction_down > 0.95).float()).bool()

                # Apply the filter mask to the features and its labels
                labeled_features_correct = labeled_features_ema.permute(
                    0, 2, 3, 1)
                labels_down_correct = labels_down[mask_prediction_correctly]
                labeled_features_correct = labeled_features_correct[
                    mask_prediction_correctly, ...]

                # get projected features
                with torch.no_grad():
                    if use_teacher:
                        proj_labeled_features_correct = ema_model.projection_head(
                            labeled_features_correct)
                    else:
                        model.eval()
                        proj_labeled_features_correct = model.projection_head(
                            labeled_features_correct)
                        model.train()

                # updated memory bank
                feature_memory.add_features_from_sample_learned(
                    ema_model, proj_labeled_features_correct,
                    labels_down_correct, batch_size_labeled)

            if i_iter > RAMP_UP_ITERS:
                '''
                CONTRASTIVE LEARNING ON LABELED DATA. Force features from labeled samples, to be similar to other features from the same class (which also leads to good predictions
                '''
                # mask features that do not have ignore label in the labels (zero-padding because of data augmentation like resize/crop)
                mask_prediction_correctly = (labels_down != ignore_label)

                labeled_features_all = labeled_features.permute(0, 2, 3, 1)
                labels_down_all = labels_down[mask_prediction_correctly]
                labeled_features_all = labeled_features_all[
                    mask_prediction_correctly, ...]

                # get predicted features
                proj_labeled_features_all = model.projection_head(
                    labeled_features_all)
                pred_labeled_features_all = model.prediction_head(
                    proj_labeled_features_all)

                # Apply contrastive learning loss
                loss_contr_labeled = contrastive_class_to_class_learned_memory(
                    model, pred_labeled_features_all, labels_down_all,
                    num_classes, feature_memory.memory)

                loss = loss + loss_contr_labeled * 0.1

        else:  # unlabeled data optimization
            ''' UNLABELED SAMPLES '''
            try:
                batch_remain = next(trainloader_remain_iter)
                if batch_remain[0].shape[0] != batch_size_unlabeled:
                    batch_remain = next(trainloader_remain_iter)
            except:
                trainloader_remain_iter = iter(trainloader_remain)
                batch_remain = next(trainloader_remain_iter)

            # Unlabeled
            unlabeled_images, _, _, _, _ = batch_remain
            unlabeled_images = unlabeled_images.cuda()

            # Create pseudolabels
            with torch.no_grad():
                if use_teacher:
                    logits_u_w, features_weak_unlabeled = ema_model(
                        normalize(unlabeled_images, dataset),
                        return_features=True)
                else:
                    model.eval()
                    logits_u_w, features_weak_unlabeled = model(
                        normalize(unlabeled_images, dataset),
                        return_features=True)
                logits_u_w = interp(
                    logits_u_w).detach()  # prediction unlabeled
                softmax_u_w = torch.softmax(logits_u_w, dim=1)
                max_probs, pseudo_label = torch.max(softmax_u_w,
                                                    dim=1)  # Get pseudolabels

            model.train()

            if dataset == 'cityscapes':
                class_weights_curr.add_frequencies_unlabeled(
                    pseudo_label.cpu().numpy())
            '''
            UNLABELED DATA
            '''
            unlabeled_images_aug1, pseudo_label1, max_probs1, unlabeled_aug1_params = augment_samples(
                unlabeled_images, pseudo_label, max_probs,
                i_iter > RAMP_UP_ITERS and random.random() < 0.75,
                batch_size_unlabeled, ignore_label)

            unlabeled_images_aug2, pseudo_label2, max_probs2, unlabeled_aug2_params = augment_samples(
                unlabeled_images, pseudo_label, max_probs,
                i_iter > RAMP_UP_ITERS and random.random() < 0.75,
                batch_size_unlabeled, ignore_label)
            # concatenate two augmentations of unlabeled data
            joined_unlabeled = torch.cat(
                (unlabeled_images_aug1, unlabeled_images_aug2), dim=0)
            joined_pseudolabels = torch.cat((pseudo_label1, pseudo_label2),
                                            dim=0)
            joined_maxprobs = torch.cat((max_probs1, max_probs2), dim=0)

            pred_joined_unlabeled, features_joined_unlabeled = model(
                normalize(joined_unlabeled, dataset), return_features=True)
            pred_joined_unlabeled = interp(pred_joined_unlabeled)

            # apply clas balance for cityspcaes dataset
            if dataset == 'cityscapes':
                class_weights = torch.from_numpy(
                    class_weights_curr.get_weights(num_iterations,
                                                   only_labeled=False)).cuda()
            else:
                class_weights = torch.from_numpy(np.ones((num_classes))).cuda()

            loss = 0

            # SELF-SUPERVISED SEGMENTATION
            unlabeled_loss = CrossEntropyLoss2dPixelWiseWeighted(
                ignore_index=ignore_label,
                weight=class_weights.float()).cuda()  #

            # Pseudo-label weighting
            pixelWiseWeight = sigmoid_ramp_up(
                i_iter, RAMP_UP_ITERS) * torch.ones(
                    joined_maxprobs.shape).cuda()
            pixelWiseWeight = pixelWiseWeight * torch.pow(
                joined_maxprobs.detach(), 6)

            # Pseudo-label loss
            loss_ce_unlabeled = unlabeled_loss(pred_joined_unlabeled,
                                               joined_pseudolabels,
                                               pixelWiseWeight)

            loss = loss + loss_ce_unlabeled

            # entropy loss
            valid_mask = (joined_pseudolabels != ignore_label).unsqueeze(1)
            loss = loss + entropy_loss(
                torch.nn.functional.softmax(pred_joined_unlabeled, dim=1),
                valid_mask) * 0.01

            if i_iter > RAMP_UP_ITERS:
                '''
                CONTRASTIVE LEARNING ON UNLABELED DATA. align unlabeled features to labeled features
                '''
                joined_pseudolabels_down = nn.functional.interpolate(
                    joined_pseudolabels.float().unsqueeze(1),
                    size=(features_joined_unlabeled.shape[2],
                          features_joined_unlabeled.shape[3]),
                    mode='nearest').squeeze(1)

                # mask features that do not have ignore label in the labels (zero-padding because of data augmentation like resize/crop)
                mask = (joined_pseudolabels_down != ignore_label)

                features_joined_unlabeled = features_joined_unlabeled.permute(
                    0, 2, 3, 1)
                features_joined_unlabeled = features_joined_unlabeled[mask,
                                                                      ...]
                joined_pseudolabels_down = joined_pseudolabels_down[mask]

                # get predicted features
                proj_feat_unlabeled = model.projection_head(
                    features_joined_unlabeled)
                pred_feat_unlabeled = model.prediction_head(
                    proj_feat_unlabeled)

                # Apply contrastive learning loss
                loss_contr_unlabeled = contrastive_class_to_class_learned_memory(
                    model, pred_feat_unlabeled, joined_pseudolabels_down,
                    num_classes, feature_memory.memory)

                loss = loss + loss_contr_unlabeled * 0.1

        # common code

        loss_l_value += loss.item()

        # optimize
        loss.backward()
        optimizer.step()

        m = 1 - (1 - 0.995) * (math.cos(math.pi * i_iter / num_iterations) +
                               1) / 2
        ema_model = update_ema_variables(ema_model=ema_model,
                                         model=model,
                                         alpha_teacher=m,
                                         iteration=i_iter)

        if i_iter % save_checkpoint_every == 0 and i_iter != 0:
            _save_checkpoint(i_iter, model, optimizer, config)

        if i_iter % val_per_iter == 0 and i_iter != 0:
            print('iter = {0:6d}/{1:6d}'.format(i_iter, num_iterations))

            model.eval()
            mIoU, eval_loss = evaluate(model,
                                       dataset,
                                       deeplabv2=deeplabv2,
                                       ignore_label=ignore_label,
                                       save_dir=checkpoint_dir,
                                       pretraining=pretraining)
            model.train()

            if mIoU > best_mIoU:
                best_mIoU = mIoU
                if save_teacher:
                    _save_checkpoint(i_iter,
                                     ema_model,
                                     optimizer,
                                     config,
                                     save_best=True)
                else:
                    _save_checkpoint(i_iter,
                                     model,
                                     optimizer,
                                     config,
                                     save_best=True)
                iters_without_improve = 0
            else:
                iters_without_improve += val_per_iter
            '''
            if the performance has not improve in N iterations, try to reload best model to optimize again with a lower LR
            Simulating an iterative training'''
            if iters_without_improve > num_iterations / 5.:
                print('Re-loading a previous best model')
                checkpoint = torch.load(
                    os.path.join(checkpoint_dir, f'best_model.pth'))
                model.load_state_dict(checkpoint['model'])
                ema_model = create_ema_model(model, Res_Deeplab)
                ema_model.train()
                ema_model = ema_model.cuda()
                model.train()
                model = model.cuda()
                iters_without_improve = 0  # reset timer

    _save_checkpoint(num_iterations, model, optimizer, config)

    # FINISH TRAINING, evaluate again
    model.eval()
    mIoU, eval_loss = evaluate(model,
                               dataset,
                               deeplabv2=deeplabv2,
                               ignore_label=ignore_label,
                               save_dir=checkpoint_dir,
                               pretraining=pretraining)
    model.train()

    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter, model, optimizer, config, save_best=True)

    # TRY IMPROVING BEST MODEL WITH EMA MODEL OR UPDATING BN STATS

    # Load best model
    checkpoint = torch.load(os.path.join(checkpoint_dir, f'best_model.pth'))
    model.load_state_dict(checkpoint['model'])
    model = model.cuda()

    model = update_BN_weak_unlabeled_data(model, normalize,
                                          batch_size_unlabeled,
                                          trainloader_remain)
    model.eval()
    mIoU, eval_loss = evaluate(model,
                               dataset,
                               deeplabv2=deeplabv2,
                               ignore_label=ignore_label,
                               save_dir=checkpoint_dir,
                               pretraining=pretraining)
    model.train()
    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter, model, optimizer, config, save_best=True)

    print('BEST MIOU')
    print(max(best_mIoU_improved, best_mIoU))

    end = timeit.default_timer()
    print('Total time: ' + str(end - start) + ' seconds')
Example #5
0
def main():
    """Train a segmentation network with optional class-mix mean-teacher self-training.

    Builds ``Res_Deeplab``, restores matching pretrained weights from
    ``restore_from`` and trains for ``num_iterations`` steps. Every step computes
    a supervised loss (``loss_calc``) on a labeled batch -- multiview frames with
    segmentation masks when ``dataset == 'multiview'``, GTA otherwise. When
    ``train_unlabeled`` is set it also computes a consistency loss between the
    student's prediction on images mixed with class masks and pseudo-labels
    produced by the EMA teacher. Checkpoints, TensorBoard logging and
    evaluation run periodically, and the best-mIoU model is saved.

    All configuration is read from module-level globals (``config``, ``args``,
    ``dataset``, ``gpus``, ``batch_size``, ...).

    Raises:
        ValueError: if ``dataset``, ``optimizer_type``, ``mix_mask``,
            ``pixel_weight`` or ``consistency_loss`` holds an unsupported value
            on a code path that needs it.
    """

    def _next_full_batch(loader_iter, loader, expected_size):
        """Return ``(batch, iterator, restarted)`` for the next batch of *loader*.

        A batch whose first tensor does not have *expected_size* rows (e.g. a
        short last batch) is skipped once. When the iterator is exhausted it is
        recreated from *loader* and ``restarted`` is True.
        """
        try:
            batch = next(loader_iter)
            if batch[0].shape[0] != expected_size:
                batch = next(loader_iter)
            return batch, loader_iter, False
        except StopIteration:
            loader_iter = iter(loader)
            return next(loader_iter), loader_iter, True

    def _evaluate(net):
        """Run ``evaluate`` on *net* with the settings for ``dataset``; return ``(mIoU, loss)``."""
        if dataset == 'cityscapes':
            return evaluate(net,
                            dataset,
                            ignore_label=250,
                            input_size=(512, 1024),
                            save_dir=checkpoint_dir)
        if dataset == 'multiview':
            return evaluate(net,
                            dataset,
                            ignore_label=255,
                            input_size=(300, 300),
                            save_dir=checkpoint_dir)
        raise ValueError('unsupported dataset: {}'.format(dataset))

    print(config)

    best_mIoU = 0

    # Consistency criterion for the unlabeled branch (only used if train_unlabeled).
    unlabeled_loss = None
    if consistency_loss == 'MSE':
        unlabeled_loss = MSELoss2d()
    elif consistency_loss == 'CE':
        unlabeled_loss = CrossEntropyLoss2dPixelWiseWeighted(
            ignore_index=ignore_label)
    if unlabeled_loss is not None:
        if len(gpus) > 1:
            unlabeled_loss = torch.nn.DataParallel(unlabeled_loss,
                                                   device_ids=gpus)
        unlabeled_loss = unlabeled_loss.cuda()

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=num_classes)

    # Load pretrained parameters from a URL or from a local file.
    if restore_from.startswith('http'):
        saved_state_dict = model_zoo.load_url(restore_from)
    else:
        saved_state_dict = torch.load(restore_from)

    # Copy only parameters whose name and shape both match the current model,
    # so checkpoints with e.g. a different number of classes can still be used.
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # init ema-model (teacher) from the student before any DataParallel wrapping
    if train_unlabeled:
        ema_model = create_ema_model(model)
        ema_model.train()
        ema_model = ema_model.cuda()
    else:
        ema_model = None

    if len(gpus) > 1:
        if use_sync_batchnorm:
            model = convert_model(model)
            model = DataParallelWithCallback(model, device_ids=gpus)
        else:
            model = torch.nn.DataParallel(model, device_ids=gpus)
    model.train()
    model.cuda()

    cudnn.benchmark = True
    data_loader = get_loader(config['dataset'])
    multiview_data_path = '/tmp/tcn_data/texture_multibot_push_left10050/videos/train_adaptation'

    # UNLABELED (adaptation) DATA -- only the images of these batches are used.
    data_aug = Compose([RandomHorizontallyFlip()])
    if dataset == 'cityscapes':
        # ``data_path`` is assigned elsewhere in this function, which makes it a
        # local name throughout; it must be resolved here before use.
        data_path = get_data_path(config['dataset'])
        train_dataset = data_loader(data_path,
                                    is_transform=True,
                                    augmentations=data_aug,
                                    img_size=input_size,
                                    img_mean=IMG_MEAN)
    elif dataset == 'multiview':
        train_dataset = data_loader(multiview_data_path,
                                    is_transform=True,
                                    view_idx=0,
                                    number_views=1,
                                    load_seg_mask=False,
                                    augmentations=data_aug,
                                    img_size=input_size,
                                    img_mean=IMG_MEAN)
    else:
        raise ValueError('unsupported dataset: {}'.format(dataset))

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    if labeled_samples is None:
        remain_workers = num_workers
    else:
        print('Training on number of samples:', labeled_samples)
        np.random.seed(random_seed)
        remain_workers = 4
    trainloader_remain = data.DataLoader(train_dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=remain_workers,
                                         pin_memory=True)
    trainloader_remain_iter = iter(trainloader_remain)

    # SUPERVISED DATA (domain transfer): multiview frames with masks, else GTA.
    data_aug = Compose([RandomHorizontallyFlip()])
    if dataset == 'multiview':
        train_dataset = data_loader(multiview_data_path,
                                    is_transform=True,
                                    view_idx=0,
                                    number_views=1,
                                    load_seg_mask=True,
                                    augmentations=data_aug,
                                    img_size=input_size,
                                    img_mean=IMG_MEAN)
    else:
        data_loader = get_loader('gta')
        data_path = get_data_path('gta')
        train_dataset = data_loader(data_path,
                                    list_path='./data/gta5_list/train.txt',
                                    augmentations=data_aug,
                                    img_size=(1280, 720),
                                    mean=IMG_MEAN)

    trainloader = data.DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  pin_memory=True)

    # training loss for labeled data only
    trainloader_iter = iter(trainloader)
    print('gta size:', len(trainloader))

    # optimizer for segmentation network; DataParallel keeps the real network
    # (and its parameter groups) on ``.module``.
    learning_rate_object = Learning_Rate_Object(
        config['training']['learning_rate'])
    base_model = model.module if len(gpus) > 1 else model
    if optimizer_type == 'SGD':
        optimizer = optim.SGD(base_model.optim_parameters(learning_rate_object),
                              lr=learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
    elif optimizer_type == 'Adam':
        # torch.optim.Adam has no ``momentum`` argument (it raises TypeError).
        optimizer = optim.Adam(
            base_model.optim_parameters(learning_rate_object),
            lr=learning_rate,
            weight_decay=weight_decay)
    else:
        raise ValueError('unsupported optimizer_type: {}'.format(optimizer_type))

    optimizer.zero_grad()

    interp = nn.Upsample(size=(input_size[0], input_size[1]),
                         mode='bilinear',
                         align_corners=True)
    start_iteration = 0

    if args.resume:
        start_iteration, model, optimizer, ema_model = _resume_checkpoint(
            args.resume, model, optimizer, ema_model)

    accumulated_loss_l = []
    accumulated_loss_u = []

    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(os.path.join(checkpoint_dir, 'config.json'), 'w') as handle:
        json.dump(config, handle, indent=4, sort_keys=True)

    # Teacher confidence above which a pseudo-labelled pixel counts as reliable.
    confidence_threshold = 0.968
    tensorboard_writer = None
    epochs_since_start = 0
    i_iter = start_iteration  # keeps i_iter bound if the loop body never runs
    for i_iter in range(start_iteration, num_iterations):
        model.train()

        loss_u_value = 0
        loss_l_value = 0

        optimizer.zero_grad()

        if lr_schedule:
            adjust_learning_rate(optimizer, i_iter)

        # Labeled batch; a recreated iterator marks the start of a new epoch.
        batch, trainloader_iter, restarted = _next_full_batch(
            trainloader_iter, trainloader, batch_size)
        if restarted:
            epochs_since_start = epochs_since_start + 1
            print('Epochs since start: ', epochs_since_start)

        weak_parameters = {"flip": 0}

        images, labels, _, _ = batch
        images = images.cuda()
        labels = labels.cuda().long()

        pred = interp(model(images))
        L_l = loss_calc(pred, labels)  # Cross entropy loss for labeled data

        if train_unlabeled:
            batch_remain, trainloader_remain_iter, _ = _next_full_batch(
                trainloader_remain_iter, trainloader_remain, batch_size)

            images_remain, *_ = batch_remain
            images_remain = images_remain.cuda()
            inputs_u_w, _ = weakTransform(weak_parameters, data=images_remain)
            # Teacher pseudo-labels are detached below, so skip building a graph.
            with torch.no_grad():
                logits_u_w = interp(ema_model(inputs_u_w))
            logits_u_w, _ = weakTransform(
                getWeakInverseTransformParameters(weak_parameters),
                data=logits_u_w.detach())

            pseudo_label = torch.softmax(logits_u_w.detach(), dim=1)
            max_probs, targets_u_w = torch.max(pseudo_label, dim=1)

            # NOTE(review): only samples 0 and 1 are mixed below, so this branch
            # effectively assumes batch_size == 2 -- confirm for all configs.
            if mix_mask == "class":
                for image_i in range(batch_size):
                    classes = torch.unique(labels[image_i])
                    nclasses = classes.shape[0]
                    # Select half (rounded up) of the classes in the labeled mask.
                    classes = (classes[torch.Tensor(
                        np.random.choice(nclasses,
                                         int((nclasses + nclasses % 2) / 2),
                                         replace=False)).long()]).cuda()

                    if image_i == 0:
                        MixMask0 = transformmasks.generate_class_mask(
                            labels[image_i], classes).unsqueeze(0).cuda()
                    else:
                        MixMask1 = transformmasks.generate_class_mask(
                            labels[image_i], classes).unsqueeze(0).cuda()
            else:
                # Only class-based mixing builds the MixMask0/MixMask1 used below.
                raise ValueError('unsupported mix_mask: {}'.format(mix_mask))

            strong_parameters = {"Mix": MixMask0}
            if random_flip:
                strong_parameters["flip"] = random.randint(0, 1)
            else:
                strong_parameters["flip"] = 0
            if color_jitter:
                strong_parameters["ColorJitter"] = random.uniform(0, 1)
            else:
                strong_parameters["ColorJitter"] = 0
            if gaussian_blur:
                strong_parameters["GaussianBlur"] = random.uniform(0, 1)
            else:
                strong_parameters["GaussianBlur"] = 0

            # Mix each labeled image with its unlabeled counterpart.
            inputs_u_s0, _ = strongTransform(
                strong_parameters,
                data=torch.cat(
                    (images[0].unsqueeze(0), images_remain[0].unsqueeze(0))))
            strong_parameters["Mix"] = MixMask1
            inputs_u_s1, _ = strongTransform(
                strong_parameters,
                data=torch.cat(
                    (images[1].unsqueeze(0), images_remain[1].unsqueeze(0))))
            inputs_u_s = torch.cat((inputs_u_s0, inputs_u_s1))
            logits_u_s = interp(model(inputs_u_s))

            # Mix the ground-truth labels with the pseudo-labels the same way.
            strong_parameters["Mix"] = MixMask0
            _, targets_u0 = strongTransform(strong_parameters,
                                            target=torch.cat(
                                                (labels[0].unsqueeze(0),
                                                 targets_u_w[0].unsqueeze(0))))
            strong_parameters["Mix"] = MixMask1
            _, targets_u1 = strongTransform(strong_parameters,
                                            target=torch.cat(
                                                (labels[1].unsqueeze(0),
                                                 targets_u_w[1].unsqueeze(0))))
            targets_u = torch.cat((targets_u0, targets_u1)).long()

            if pixel_weight == "threshold_uniform":
                # One scalar weight: fraction of confident teacher pixels.
                unlabeled_weight = max_probs.ge(
                    confidence_threshold).sum().item() / targets_u.numel()
                pixelWiseWeight = unlabeled_weight * torch.ones(
                    max_probs.shape).cuda()
            elif pixel_weight == "threshold":
                pixelWiseWeight = max_probs.ge(confidence_threshold).float().cuda()
            elif pixel_weight == False:  # noqa: E712 -- also accepts 0, as before
                pixelWiseWeight = torch.ones(max_probs.shape).cuda()
            else:
                raise ValueError(
                    'unsupported pixel_weight: {}'.format(pixel_weight))

            # Labeled (pasted) pixels get weight 1; unlabeled pixels keep theirs.
            onesWeights = torch.ones((pixelWiseWeight.shape)).cuda()
            strong_parameters["Mix"] = MixMask0
            _, pixelWiseWeight0 = strongTransform(
                strong_parameters,
                target=torch.cat((onesWeights[0].unsqueeze(0),
                                  pixelWiseWeight[0].unsqueeze(0))))
            strong_parameters["Mix"] = MixMask1
            _, pixelWiseWeight1 = strongTransform(
                strong_parameters,
                target=torch.cat((onesWeights[1].unsqueeze(0),
                                  pixelWiseWeight[1].unsqueeze(0))))
            pixelWiseWeight = torch.cat(
                (pixelWiseWeight0, pixelWiseWeight1)).cuda()

            if consistency_loss == 'MSE':
                unlabeled_weight = max_probs.ge(
                    confidence_threshold).sum().item() / targets_u.numel()
                L_u = consistency_weight * unlabeled_weight * unlabeled_loss(
                    logits_u_s, pseudo_label)
            elif consistency_loss == 'CE':
                L_u = consistency_weight * unlabeled_loss(
                    logits_u_s, targets_u, pixelWiseWeight)
            else:
                raise ValueError(
                    'unsupported consistency_loss: {}'.format(consistency_loss))

            loss = L_l + L_u

        else:
            loss = L_l

        if len(gpus) > 1:
            # DataParallel criteria return one loss per device; reduce them.
            loss = loss.mean()
            loss_l_value += L_l.mean().item()
            if train_unlabeled:
                loss_u_value += L_u.mean().item()
        else:
            loss_l_value += L_l.item()
            if train_unlabeled:
                loss_u_value += L_u.item()

        loss.backward()
        optimizer.step()

        # update Mean teacher network
        if ema_model is not None:
            alpha_teacher = 0.99
            ema_model = update_ema_variables(ema_model=ema_model,
                                             model=model,
                                             alpha_teacher=alpha_teacher,
                                             iteration=i_iter)

        print(
            'iter = {0:6d}/{1:6d}, loss_l = {2:.3f}, loss_u = {3:.3f}'.format(
                i_iter, num_iterations, loss_l_value, loss_u_value))

        if i_iter % save_checkpoint_every == 0 and i_iter != 0:
            if epochs_since_start * len(trainloader) < save_checkpoint_every:
                _save_checkpoint(i_iter,
                                 model,
                                 optimizer,
                                 config,
                                 ema_model,
                                 overwrite=False)
            else:
                _save_checkpoint(i_iter, model, optimizer, config, ema_model)

        if config['utils']['tensorboard']:
            if tensorboard_writer is None:
                tensorboard_writer = tensorboard.SummaryWriter(log_dir,
                                                               flush_secs=30)

            accumulated_loss_l.append(loss_l_value)
            if train_unlabeled:
                accumulated_loss_u.append(loss_u_value)
            if i_iter % log_per_iter == 0 and i_iter != 0:

                tensorboard_writer.add_scalar('Training/Supervised loss',
                                              np.mean(accumulated_loss_l),
                                              i_iter)
                accumulated_loss_l = []

                if train_unlabeled:
                    tensorboard_writer.add_scalar('Training/Unsupervised loss',
                                                  np.mean(accumulated_loss_u),
                                                  i_iter)
                    accumulated_loss_u = []

        if i_iter % val_per_iter == 0 and i_iter != 0:
            model.eval()
            mIoU, eval_loss = _evaluate(model)
            model.train()

            if mIoU > best_mIoU and save_best_model:
                best_mIoU = mIoU
                _save_checkpoint(i_iter,
                                 model,
                                 optimizer,
                                 config,
                                 ema_model,
                                 save_best=True)

            if config['utils']['tensorboard']:
                tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
                tensorboard_writer.add_scalar('Validation/Loss', eval_loss,
                                              i_iter)
                print('iter {}, mIoU: {}'.format(i_iter, mIoU))

        if save_unlabeled_images and train_unlabeled and i_iter % save_checkpoint_every == 0:
            # Saves two mixed images and the corresponding prediction
            save_image(inputs_u_s[0].cpu(), i_iter, 'input1',
                       palette.CityScpates_palette)
            save_image(inputs_u_s[1].cpu(), i_iter, 'input2',
                       palette.CityScpates_palette)
            _, pred_u_s = torch.max(logits_u_s, dim=1)
            save_image(pred_u_s[0].cpu(), i_iter, 'pred1',
                       palette.CityScpates_palette)
            save_image(pred_u_s[1].cpu(), i_iter, 'pred2',
                       palette.CityScpates_palette)

    _save_checkpoint(num_iterations, model, optimizer, config, ema_model)

    # Final evaluation after training.
    model.eval()
    mIoU, val_loss = _evaluate(model)
    model.train()
    if mIoU > best_mIoU and save_best_model:
        _save_checkpoint(i_iter,
                         model,
                         optimizer,
                         config,
                         ema_model,
                         save_best=True)

    if config['utils']['tensorboard']:
        if tensorboard_writer is None:
            tensorboard_writer = tensorboard.SummaryWriter(log_dir,
                                                           flush_secs=30)
        tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
        tensorboard_writer.add_scalar('Validation/Loss', val_loss, i_iter)
    if tensorboard_writer is not None:
        # Flush pending events to disk before exiting.
        tensorboard_writer.close()

    end = timeit.default_timer()
    print('Total time: ' + str(end - start) + ' seconds')
Example #6
0
def main():
    """Fine-tune a DeepLab segmentation network with supervised cross-entropy.

    Every setting is read from the module-level ``args`` namespace:

    * builds ``Res_Deeplab(num_classes=args.num_classes)`` and copies into it
      the tensors from ``args.restore_from`` whose name and shape match;
    * builds the training set selected by ``args.dataset``; when
      ``args.labeled_ratio`` is set, it samples only that fraction of it
      (order taken from ``args.split_id`` or a fresh shuffle). The order is
      saved to ``<checkpoint_dir>/split.pkl``;
    * runs ``args.num_steps`` SGD iterations and writes ``VOC_<iter>.pth``
      checkpoints into ``args.checkpoint_dir`` every
      ``args.save_pred_every`` steps and at the end.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)
    model.cuda()

    # load pretrained parameters
    saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like); tensors
    # that are missing or differ in shape keep their fresh initialisation
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)

    if args.dataset == 'pascal_voc':
        train_dataset = VOCDataSet(args.data_dir,
                                   args.data_list,
                                   crop_size=input_size,
                                   scale=args.random_scale,
                                   mirror=args.random_mirror,
                                   mean=IMG_MEAN)
    elif args.dataset == 'CMR':
        train_dataset = CMRDataSet(args.data_dir,
                                   args.data_list,
                                   crop_size=input_size,
                                   scale=args.random_scale,
                                   mirror=args.random_mirror,
                                   mean=IMG_MEAN)
    elif args.dataset == 'pascal_context':
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        data_kwargs = {
            'transform': input_transform,
            'base_size': 505,
            'crop_size': 321
        }
        data_loader = get_loader('pascal_context')
        data_path = get_data_path('pascal_context')
        train_dataset = data_loader(data_path,
                                    split='train',
                                    mode='train',
                                    **data_kwargs)

    elif args.dataset == 'cityscapes':
        data_loader = get_loader('cityscapes')
        data_path = get_data_path('cityscapes')
        data_aug = Compose(
            [RandomCrop_city((256, 512)),
             RandomHorizontallyFlip()])
        train_dataset = data_loader(data_path,
                                    is_transform=True,
                                    augmentations=data_aug)
    else:
        # Fail loudly here instead of with an UnboundLocalError on
        # `train_dataset` a few lines below.
        raise ValueError('unsupported dataset: {}'.format(args.dataset))

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    if args.labeled_ratio is None:
        trainloader = data.DataLoader(train_dataset,
                                      shuffle=True,
                                      batch_size=args.batch_size,
                                      num_workers=4,
                                      pin_memory=True)
    else:
        partial_size = int(args.labeled_ratio * train_dataset_size)

        if args.split_id is not None:
            # NOTE: pickle can execute arbitrary code; only load split files
            # from a trusted source.
            with open(args.split_id, 'rb') as split_file:
                train_ids = pickle.load(split_file)
            print('loading train ids from {}'.format(args.split_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        # Persist the ordering so the same labeled subset can be reused later
        # via args.split_id.
        with open(os.path.join(args.checkpoint_dir, 'split.pkl'),
                  'wb') as split_file:
            pickle.dump(train_ids, split_file)

        # Only the first `partial_size` ids are treated as labeled.
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=4,
                                      pin_memory=True)

    trainloader_iter = iter(trainloader)

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # loss/ bilinear upsampling back to the crop size
    interp = nn.Upsample(size=(input_size[0], input_size[1]),
                         mode='bilinear',
                         align_corners=True)

    for i_iter in range(args.num_steps):

        loss_value = 0
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        try:
            batch_lab = next(trainloader_iter)
        except StopIteration:
            # Epoch exhausted: start a new pass over the (re-sampled) loader.
            trainloader_iter = iter(trainloader)
            batch_lab = next(trainloader_iter)

        images, labels, _, _, _ = batch_lab
        images = Variable(images).cuda(args.gpu)

        pred = interp(model(images))
        loss = loss_calc(pred, labels, args.gpu)

        loss.backward()
        loss_value += loss.item()

        optimizer.step()

        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}'.format(
            i_iter, args.num_steps, loss_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.checkpoint_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('saving checkpoint ...')
            torch.save(
                model.state_dict(),
                osp.join(args.checkpoint_dir, 'VOC_' + str(i_iter) + '.pth'))

    end = timeit.default_timer()
    # NOTE(review): `start` is not defined in this function; presumably a
    # module-level timer set before main() runs — confirm.
    print(end - start, 'seconds')
Example #7
0
def main():
    """Train a segmentation network semi-supervised with an s4GAN discriminator.

    Every setting is read from the module-level ``args`` namespace. Each of
    the ``args.num_steps`` iterations does the following:

    * computes cross-entropy on a batch from the labeled split;
    * runs the network on a batch from the remaining split and feeds the
      discriminator ``model_D`` the channel-wise concatenation of the
      softmax prediction and the min-max-scaled image;
    * adds a self-training cross-entropy on the predictions that
      ``find_good_maps`` selects from D's output (when any are selected and
      ``i_iter > 0``);
    * adds an L1 feature-matching loss between D's features on real inputs
      (one-hot ground truth + image) and on predicted inputs;
    * updates D with ``criterion`` on predicted (target 0) versus real
      (target 1) inputs.

    Checkpoints for both networks are written to ``args.checkpoint_dir``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    print(args)

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters; only tensors whose name and shape match the
    # current model are copied, the rest keep their fresh initialisation
    saved_state_dict = torch.load(args.restore_from)

    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D = s4GAN_discriminator(num_classes=args.num_classes,
                                  dataset=args.dataset)

    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)

    if args.dataset == 'pascal_voc':
        train_dataset = VOCDataSet(args.data_dir,
                                   args.data_list,
                                   crop_size=input_size,
                                   scale=args.random_scale,
                                   mirror=args.random_mirror,
                                   mean=IMG_MEAN)

    elif args.dataset == 'pascal_context':
        # NOTE(review): the mean list is in reversed (.406, .456, .485) order
        # while the std list is not; confirm this matches the channel order
        # produced by the pascal_context loader.
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.406, .456, .485], [.229, .224, .225])
        ])
        data_kwargs = {
            'transform': input_transform,
            'base_size': 505,
            'crop_size': 321
        }
        data_loader = get_loader('pascal_context')
        data_path = get_data_path('pascal_context')
        train_dataset = data_loader(data_path,
                                    split='train',
                                    mode='train',
                                    **data_kwargs)

    elif args.dataset == 'cityscapes':
        data_loader = get_loader('cityscapes')
        data_path = get_data_path('cityscapes')
        data_aug = Compose(
            [RandomCrop_city((256, 512)),
             RandomHorizontallyFlip()])
        train_dataset = data_loader(data_path,
                                    is_transform=True,
                                    augmentations=data_aug)
    else:
        # Fail loudly here instead of with an UnboundLocalError on
        # `train_dataset` a few lines below.
        raise ValueError('unsupported dataset: {}'.format(args.dataset))

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    if args.labeled_ratio is None:
        # Fully labeled: all three loaders draw from the whole dataset.
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=4,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=4,
                                         pin_memory=True)

        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=4,
                                             pin_memory=True)
        trainloader_remain_iter = iter(trainloader_remain)

    else:
        partial_size = int(args.labeled_ratio * train_dataset_size)

        if args.split_id is not None:
            # NOTE: pickle can execute arbitrary code; only load split files
            # from a trusted source.
            with open(args.split_id, 'rb') as split_file:
                train_ids = pickle.load(split_file)
            print('loading train ids from {}'.format(args.split_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        # Persist the ordering so the same labeled subset can be reused later
        # via args.split_id.
        with open(os.path.join(args.checkpoint_dir, 'train_voc_split.pkl'),
                  'wb') as split_file:
            pickle.dump(train_ids, split_file)

        # First `partial_size` ids are labeled; the rest are used unlabeled.
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=4,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=4,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=4,
                                         pin_memory=True)

        trainloader_remain_iter = iter(trainloader_remain)

    trainloader_iter = iter(trainloader)
    trainloader_gt_iter = iter(trainloader_gt)

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    interp = nn.Upsample(size=(input_size[0], input_size[1]),
                         mode='bilinear',
                         align_corners=True)

    for i_iter in range(args.num_steps):

        loss_ce_value = 0
        loss_D_value = 0
        loss_fm_value = 0
        loss_S_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train Segmentation Network
        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # training loss for labeled data only
        try:
            batch = next(trainloader_iter)
        except StopIteration:
            trainloader_iter = iter(trainloader)
            batch = next(trainloader_iter)

        images, labels, _, _, _ = batch
        images = Variable(images).cuda(args.gpu)
        pred = interp(model(images))
        loss_ce = loss_calc(pred, labels,
                            args.gpu)  # Cross entropy loss for labeled data

        # training loss for remaining unlabeled data
        try:
            batch_remain = next(trainloader_remain_iter)
        except StopIteration:
            trainloader_remain_iter = iter(trainloader_remain)
            batch_remain = next(trainloader_remain_iter)

        images_remain, _, _, _, _ = batch_remain
        images_remain = Variable(images_remain).cuda(args.gpu)
        pred_remain = interp(model(images_remain))

        # concatenate the prediction with the min-max-scaled input images
        images_remain = (images_remain - torch.min(images_remain)) / (
            torch.max(images_remain) - torch.min(images_remain))
        pred_cat = torch.cat((F.softmax(pred_remain, dim=1), images_remain),
                             dim=1)

        # D returns its output score and a feature map used by the FM loss
        D_out_z, D_out_y_pred = model_D(pred_cat)

        # find predicted segmentation maps above threshold
        pred_sel, labels_sel, count = find_good_maps(D_out_z, pred_remain)

        # training loss on above threshold segmentation predictions (Cross Entropy Loss)
        if count > 0 and i_iter > 0:
            loss_st = loss_calc(pred_sel, labels_sel, args.gpu)
        else:
            loss_st = 0.0

        # Concatenates the input images and ground-truth maps for the Discriminator 'Real' input
        try:
            batch_gt = next(trainloader_gt_iter)
        except StopIteration:
            trainloader_gt_iter = iter(trainloader_gt)
            batch_gt = next(trainloader_gt_iter)

        images_gt, labels_gt, _, _, _ = batch_gt
        # Converts ground truth segmentation into 'num_classes' segmentation maps.
        D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)

        # Keep images_gt on the same device as D_gt_v (required by torch.cat).
        # Scale it with its OWN min/max, the same way images_remain is scaled
        # above, so D sees real and predicted inputs on a common range.
        images_gt = images_gt.cuda(args.gpu)
        images_gt = (images_gt - torch.min(images_gt)) / (
            torch.max(images_gt) - torch.min(images_gt))

        D_gt_v_cat = torch.cat((D_gt_v, images_gt), dim=1)
        D_out_z_gt, D_out_y_gt = model_D(D_gt_v_cat)

        # L1 loss for Feature Matching Loss
        loss_fm = torch.mean(
            torch.abs(torch.mean(D_out_y_gt, 0) - torch.mean(D_out_y_pred, 0)))

        if count > 0 and i_iter > 0:  # if any good predictions found for self-training loss
            loss_S = loss_ce + args.lambda_fm * loss_fm + args.lambda_st * loss_st
        else:
            loss_S = loss_ce + args.lambda_fm * loss_fm

        loss_S.backward()
        # .item() so the logged value is a float and does not keep the graph alive
        loss_fm_value += args.lambda_fm * loss_fm.item()

        loss_ce_value += loss_ce.item()
        loss_S_value += loss_S.item()

        # train D
        for param in model_D.parameters():
            param.requires_grad = True

        # train with pred; detach so D's loss does not backprop into the
        # segmentation network
        pred_cat = pred_cat.detach()

        D_out_z, _ = model_D(pred_cat)
        y_fake_ = torch.zeros(D_out_z.size(0), 1, device=D_out_z.device)
        loss_D_fake = criterion(D_out_z, y_fake_)

        # train with gt
        D_out_z_gt, _ = model_D(D_gt_v_cat)
        y_real_ = torch.ones(D_out_z_gt.size(0), 1, device=D_out_z_gt.device)
        loss_D_real = criterion(D_out_z_gt, y_real_)

        loss_D = (loss_D_fake + loss_D_real) / 2.0
        loss_D.backward()
        loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        print(
            'iter = {0:8d}/{1:8d}, loss_ce = {2:.3f}, loss_fm = {3:.3f}, loss_S = {4:.3f}, loss_D = {5:.3f}'
            .format(i_iter, args.num_steps, loss_ce_value, loss_fm_value,
                    loss_S_value, loss_D_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                os.path.join(args.checkpoint_dir,
                             'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                os.path.join(args.checkpoint_dir,
                             'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('saving checkpoint  ...')
            torch.save(
                model.state_dict(),
                os.path.join(args.checkpoint_dir,
                             'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                os.path.join(args.checkpoint_dir,
                             'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    # NOTE(review): `start` is not defined in this function; presumably a
    # module-level timer set before main() runs — confirm.
    print(end - start, 'seconds')
Example #8
0
def main():
    print(args)

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    gpu = args.gpu

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    model.train()
    model.cuda(args.gpu)

    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True

    summary(model, (110, 64, 64))

    # load pretrained parameters
    saved_state_dict = torch.load(args.restore_from)

    #new_params = model.state_dict().copy()
    #for name, param in new_params.items():
    #    if name in saved_state_dict and param.size() == saved_state_dict[name].size():
    #        new_params[name].copy_(saved_state_dict[name])
    #model.load_state_dict(saved_state_dict)

    # init D
    model_D = s4GAN_discriminator(num_classes=args.num_classes,
                                  dataset=args.dataset)

    fig, axarr = plt.subplots(6, sharex=False, figsize=(20, 20))

    #p_class = [0.0355738,  0.00141609, 0.48528844, 0.07337901, 0.0388637,  0.1026877,
    #0.00271774, 0.18383373, 0.0359457,  0.,         0.,         0.02593297,
    #0.01436112, 0.,         0.]  # small

    p_class = [
        7.25703402e-02, 1.57180553e-01, 1.81395714e-01, 2.15331438e-01,
        8.59744781e-02, 6.45834114e-02, 2.08535688e-03, 2.95754679e-02,
        2.30909954e-02, 1.18364523e-03, 5.40670110e-04, 4.34120229e-02,
        1.03664125e-01, 2.07385748e-03, 1.73379244e-02
    ]

    #p_class = [0.0688357,  0.16084504, 0.17536666, 0.19976606, 0.12290404, 0.06232464,
    # 0.00192151, 0.02954287, 0.02216445, 0.0013851,  0.00159172, 0.03166692,
    # 0.1009182,  0.00169153, 0.01907556]  # 50%

    weights_t = (1 / torch.log(1.02 + torch.tensor(p_class))).cuda()

    model_D = torch.nn.DataParallel(model_D).cuda()

    #if args.restore_from_D is not None:
    #model_D.load_state_dict(torch.load(args.restore_from_D))

    cudnn.benchmark = True

    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)
        #os.makedirs(os.path.join(args.checkpoint_dir,'rend'))

    if args.dataset == 'pascal_voc':
        train_dataset = VOCDataSet(args.data_dir,
                                   args.data_list,
                                   crop_size=input_size,
                                   scale=args.random_scale,
                                   mirror=args.random_mirror,
                                   mean=IMG_MEAN)
        #train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
        #scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    elif args.dataset == 'pascal_context':
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.406, .456, .485], [.229, .224, .225])
        ])
        data_kwargs = {
            'transform': input_transform,
            'base_size': 505,
            'crop_size': 321
        }
        #train_dataset = get_segmentation_dataset('pcontext', split='train', mode='train', **data_kwargs)
        data_loader = get_loader('pascal_context')
        data_path = get_data_path('pascal_context')
        train_dataset = data_loader(data_path,
                                    split='train',
                                    mode='train',
                                    **data_kwargs)
        #train_gt_dataset = data_loader(data_path, split='train', mode='train', **data_kwargs)

    elif args.dataset == 'cityscapes':
        data_loader = get_loader('cityscapes')
        data_path = get_data_path('cityscapes')
        data_aug = Compose(
            [RandomCrop_city((256, 512)),
             RandomHorizontallyFlip()])
        train_dataset = data_loader(data_path,
                                    is_transform=True,
                                    augmentations=data_aug)
        #train_gt_dataset = data_loader( data_path, is_transform=True, augmentations=data_aug)

    elif args.dataset == 'sen2':
        data_loader = get_loader('sen2')
        data_path = get_data_path('sen2')
        train_dataset = get_dataloader(data_loader, data_path, input_size,
                                       ["labelled_train"], "train_sup")
        test_dataset = get_dataloader(data_loader, data_path, input_size,
                                      ["labelled_test"], "test")

    train_dataset_size = len(train_dataset)
    test_dataset_size = len(test_dataset)
    print('train dataset size: ', train_dataset_size)
    print('test dataset size: ', test_dataset_size)

    num_batches_train = int(train_dataset_size / args.batch_size) + 1
    last_batch_sz = train_dataset_size % args.batch_size

    num_batches_test = int(test_dataset_size / args.batch_size) + 1

    print('num batches train : ', num_batches_train)
    print('last batch size : ', last_batch_sz)

    if args.labeled_ratio is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=4,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=4,
                                         pin_memory=True)

        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=4,
                                             pin_memory=True)
        trainloader_remain_iter = iter(trainloader_remain)

        testloader = data.DataLoader(test_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     num_workers=4,
                                     pin_memory=True)

    else:
        partial_size = int(args.labeled_ratio * train_dataset_size)

        if args.split_id is not None:
            train_ids = pickle.load(open(args.split_id, 'rb'))
            print('loading train ids from {}'.format(args.split_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        pickle.dump(
            train_ids,
            open(os.path.join(args.checkpoint_dir, 'train_voc_split.pkl'),
                 'wb'))

        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=4,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=4,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=4,
                                         pin_memory=True)

        trainloader_remain_iter = iter(trainloader_remain)

    trainloader_iter = iter(trainloader)
    trainloader_gt_iter = iter(trainloader_gt)

    # optimizer for segmentation network
    optimizer = optim.SGD(model.module.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    interp = nn.Upsample(size=(input_size[0], input_size[1]),
                         mode='bilinear',
                         align_corners=True)

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    y_real_, y_fake_ = Variable(torch.ones(args.batch_size,
                                           1).cuda()), Variable(
                                               torch.zeros(args.batch_size,
                                                           1).cuda())

    train_accs = []
    test_accs = []
    train_kscores = []
    test_kscores = []
    losses_ce = []
    losses_fm = []
    losses_S = []
    losses_D = []

    e_i = 0
    loss_ce_value = 0
    loss_D_value = 0
    loss_fm_value = 0
    loss_S_value = 0
    for i_iter in range(args.num_steps):

        #loss_ce_value = 0
        #loss_D_value = 0
        #loss_fm_value = 0
        #loss_S_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train Segmentation Network
        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # training loss for labeled data only
        try:
            batch = next(trainloader_iter)
        except:
            print("end epoch %s" % e_i)
            trainloader_iter = iter(trainloader)
            batch = next(trainloader_iter)

        images, labels = batch
        images = Variable(images).cuda(args.gpu)
        pred = F.interpolate(model(images),
                             size=(input_size[0], input_size[1]),
                             mode='bilinear',
                             align_corners=True)
        loss_ce = loss_calc(pred, labels, weights_t,
                            args.gpu)  # Cross entropy loss for labeled data

        #training loss for remaining unlabeled data
        try:
            batch_remain = next(trainloader_remain_iter)
        except:
            trainloader_remain_iter = iter(trainloader_remain)
            batch_remain = next(trainloader_remain_iter)

        images_remain, labels_remain = batch_remain
        images_remain = Variable(images_remain).cuda(args.gpu)
        pred_remain = F.interpolate(model(images_remain),
                                    size=(input_size[0], input_size[1]),
                                    mode='bilinear',
                                    align_corners=True)

        # concatenate the prediction with the input images
        images_remain = (images_remain - torch.min(images_remain)) / (
            torch.max(images_remain) - torch.min(images_remain))
        #print (pred_remain.size(), images_remain.size())
        n, c, w, h = pred_remain.size()
        mask = (labels_remain != args.ignore_label)
        pred_remain[mask.view(n, 1, w, h).repeat(1, c, 1, 1)] = 0
        pred_cat = torch.cat((F.softmax(pred_remain, dim=1), images_remain),
                             dim=1)

        D_out_z, D_out_y_pred = model_D(
            pred_cat
        )  # predicts the D ouput 0-1 and feature map for FM-loss D_out_y_pred

        # find predicted segmentation maps above threshold
        pred_sel, labels_sel, count = find_good_maps(D_out_z, pred_remain)

        # training loss on above threshold segmentation predictions (Cross Entropy Loss)
        if count > 0 and i_iter > 0:
            loss_st = loss_calc(pred_sel, labels_sel, weights_t, args.gpu)
        else:
            loss_st = 0.0

        # Concatenates the input images and ground-truth maps for the Districrimator 'Real' input
        try:
            batch_gt = next(trainloader_gt_iter)
        except:
            trainloader_gt_iter = iter(trainloader_gt)
            batch_gt = next(trainloader_gt_iter)

        images_gt, labels_gt = batch_gt
        # Converts grounth truth segmentation into 'num_classes' segmentation maps.
        D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)

        images_gt = images_gt.cuda()
        images_gt = (images_gt - torch.min(images_gt)) / (torch.max(images) -
                                                          torch.min(images))

        D_gt_v_cat = torch.cat((D_gt_v, images_gt), dim=1)
        D_out_z_gt, D_out_y_gt = model_D(D_gt_v_cat)  # D_out_y_gt

        # L1 loss for Feature Matching Loss
        loss_fm = torch.mean(
            torch.abs(torch.mean(D_out_y_gt, 0) - torch.mean(D_out_y_pred, 0)))

        if count > 0 and i_iter > 0:  # if any good predictions found for self-training loss
            loss_S = loss_ce + args.lambda_st * loss_st + args.lambda_fm * loss_fm
        else:
            loss_S = loss_ce + args.lambda_fm * loss_fm

        loss_S.backward()
        loss_fm_value += args.lambda_fm * loss_fm

        loss_ce_value += loss_ce.item()
        loss_S_value += loss_S.item()

        # train D
        for param in model_D.parameters():
            param.requires_grad = True

        # train with pred
        pred_cat = pred_cat.detach(
        )  # detach does not allow the graddients to back propagate.

        D_out_z, _ = model_D(pred_cat)
        y_fake_ = Variable(torch.zeros(D_out_z.size(0), 1).cuda())
        loss_D_fake = criterion(D_out_z, y_fake_)

        # train with gt
        D_out_z_gt, _ = model_D(D_gt_v_cat)
        y_real_ = Variable(torch.ones(D_out_z_gt.size(0), 1).cuda())
        loss_D_real = criterion(D_out_z_gt, y_real_)

        loss_D = (loss_D_fake + loss_D_real) / 2.0
        loss_D.backward()
        loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        print(
            'iter = {0:8d}/{1:8d}, loss_ce = {2:.3f}, loss_fm = {3:.3f}, loss_S = {4:.3f}, loss_D = {5:.3f}'
            .format(i_iter, args.num_steps, loss_ce_value, loss_fm_value,
                    loss_S_value, loss_D_value))

        # EVALUATION
        # -------------------------------------------------------------------------------

        if (i_iter % (num_batches_train - 1) == 0) & (i_iter > 0):
            e_i += 1
            train_acc, train_kappa_score, train_cm = _eval(model,
                                                           trainloader,
                                                           args.batch_size,
                                                           input_size,
                                                           num_batches_train,
                                                           args.gpu,
                                                           render_preds=False)
            test_acc, test_kappa_score, test_cm = _eval(model,
                                                        testloader,
                                                        args.batch_size,
                                                        input_size,
                                                        num_batches_test,
                                                        args.gpu,
                                                        render_preds=False)

            print('train_acc = ', train_acc)
            print('test_acc = ', test_acc)

            cm_name_train = "confusion_matrix_%d.png" % e_i
            plot_confusion_matrix(cm=train_cm,
                                  classes=(range(args.num_classes)),
                                  normalize=True)
            plt.savefig(os.path.join(args.checkpoint_dir, cm_name_train))

            avg_loss_ce = loss_ce_value / num_batches_train
            avg_loss_fm = loss_fm_value / num_batches_train
            avg_loss_S = loss_S_value / num_batches_train
            avg_loss_D = loss_D_value / num_batches_train

            loss_ce_value = 0
            loss_D_value = 0
            loss_fm_value = 0
            loss_S_value = 0

            train_accs.append(train_acc)
            test_accs.append(test_acc)
            train_kscores.append(train_kappa_score)
            test_kscores.append(test_kappa_score)
            losses_ce.append(avg_loss_ce)
            losses_fm.append(avg_loss_fm)
            losses_S.append(avg_loss_S)
            losses_D.append(avg_loss_D)

            axarr[0].clear()
            axarr[0].plot(train_accs, 'g')
            axarr[0].plot(test_accs, 'r')
            axarr[0].set_title("acc (best), top train : %f" % max(train_accs))

            axarr[1].clear()
            axarr[1].plot(train_kscores, 'g')
            axarr[1].plot(test_kscores, 'r')
            axarr[1].set_title("Cohen's kappa score (best) : %f" %
                               max(train_kscores))

            axarr[2].clear()
            axarr[2].plot(losses_D)
            axarr[2].set_title("Loss D")

            axarr[3].clear()
            axarr[3].plot(losses_S)
            axarr[3].set_title("Loss S")

            axarr[4].clear()
            axarr[4].plot(losses_ce)
            axarr[4].set_title("Loss ce")

            axarr[5].clear()
            axarr[5].plot(losses_fm)
            axarr[5].set_title("Loss fm")

            fig.canvas.draw_idle()
            fig.savefig(os.path.join(args.checkpoint_dir, "plots.png"))

            if train_acc >= max(train_accs):
                print('save model ...')
                torch.save(model.state_dict(),
                           os.path.join(args.checkpoint_dir, 'best.pth'))
                torch.save(model_D.state_dict(),
                           os.path.join(args.checkpoint_dir, 'best_D.pth'))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(model.state_dict(),
                       os.path.join(args.checkpoint_dir, 'latest.pth'))
            torch.save(model_D.state_dict(),
                       os.path.join(args.checkpoint_dir, 'latest_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('saving checkpoint  ...')
            torch.save(
                model.state_dict(),
                os.path.join(args.checkpoint_dir,
                             'sen2_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                os.path.join(args.checkpoint_dir,
                             'sen2_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
Example #9
0
def main():
    """Evaluate a trained DeepLab model on a validation split.

    Arguments come from ``get_arguments()``. The model weights are restored
    from ``args.restore_from`` (an ``http...`` URL or a local path), the model
    is run on the validation data of ``args.dataset`` (``pascal_voc``,
    ``pascal_context`` or ``cityscapes``) and the collected
    (ground truth, prediction) pairs are passed to ``get_iou`` together with
    the path ``<args.save_dir>/result.txt``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    args = get_arguments()
    gpu0 = args.gpu

    os.makedirs(args.save_dir, exist_ok=True)

    model = Res_Deeplab(num_classes=args.num_classes)
    model.cuda()

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.cuda(gpu0)

    if args.dataset == 'pascal_voc':
        testloader = data.DataLoader(VOCDataSet(args.data_dir,
                                                args.data_list,
                                                crop_size=(505, 505),
                                                mean=IMG_MEAN,
                                                scale=False,
                                                mirror=False),
                                     batch_size=1,
                                     shuffle=False,
                                     pin_memory=True)
        interp = nn.Upsample(size=(505, 505),
                             mode='bilinear',
                             align_corners=True)

    elif args.dataset == 'pascal_context':
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        data_kwargs = {
            'transform': input_transform,
            'base_size': 512,
            'crop_size': 512
        }
        data_loader = get_loader('pascal_context')
        data_path = get_data_path('pascal_context')
        test_dataset = data_loader(data_path,
                                   split='val',
                                   mode='val',
                                   **data_kwargs)
        testloader = data.DataLoader(test_dataset,
                                     batch_size=1,
                                     drop_last=False,
                                     shuffle=False,
                                     num_workers=1,
                                     pin_memory=True)
        interp = nn.Upsample(size=(512, 512),
                             mode='bilinear',
                             align_corners=True)

    elif args.dataset == 'cityscapes':
        data_loader = get_loader('cityscapes')
        data_path = get_data_path('cityscapes')
        test_dataset = data_loader(data_path,
                                   img_size=(512, 1024),
                                   is_transform=True,
                                   split='val')
        testloader = data.DataLoader(test_dataset,
                                     batch_size=1,
                                     shuffle=False,
                                     pin_memory=True)
        interp = nn.Upsample(size=(512, 1024),
                             mode='bilinear',
                             align_corners=True)

    else:
        # Fail early instead of hitting a NameError on `testloader` below.
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))

    data_list = []
    colorize = VOCColorize()

    if args.with_mlmt:
        mlmt_preds = np.loadtxt('mlmt_output/output_ema_p_1_0_voc_5.txt',
                                dtype=float)  # best mt 0.05

        # Binarize the per-image, per-class scores at 0.2.
        mlmt_preds[mlmt_preds >= 0.2] = 1
        mlmt_preds[mlmt_preds < 0.2] = 0

    # Inference only: `Variable(volatile=True)` no longer disables autograd,
    # so gradients are switched off explicitly.
    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % (index))
            image, label, size, name, _ = batch
            size = size[0]
            output = model(image.cuda(gpu0))
            # Class score maps of the single image in the batch: (C, H, W).
            output = interp(output).cpu()[0].numpy()

            if args.dataset == 'pascal_voc':
                # Crop away the padding added up to the 505x505 crop.
                output = output[:, :size[0], :size[1]]
                gt = np.asarray(label[0].numpy()[:size[0], :size[1]],
                                dtype=int)
            else:  # pascal_context / cityscapes
                gt = np.asarray(label[0].numpy(), dtype=int)

            if args.with_mlmt:
                # Scale each class score map by the binarized mlmt score of
                # this image for that class.
                for i in range(args.num_classes):
                    output[i] = output[i] * mlmt_preds[index][i]

            # Per-pixel predicted class.
            output = np.asarray(np.argmax(output, axis=0), dtype=int)

            if args.save_output_images:
                if args.dataset == 'pascal_voc':
                    filename = os.path.join(args.save_dir,
                                            '{}.png'.format(name[0]))
                    color_file = Image.fromarray(
                        colorize(output).transpose(1, 2, 0), 'RGB')
                    color_file.save(filename)
                elif args.dataset == 'pascal_context':
                    # NOTE(review): this saves the ground truth, not the
                    # prediction, and scipy.misc.imsave is gone from recent
                    # SciPy releases -- confirm intent before relying on it.
                    filename = os.path.join(args.save_dir, name[0])
                    scipy.misc.imsave(filename, gt)

            data_list.append([gt.flatten(), output.flatten()])

    filename = os.path.join(args.save_dir, 'result.txt')
    get_iou(args, data_list, args.num_classes, filename)
def main():
    """Train a DeepLab network semi-supervised on Cityscapes plus GTA5.

    All hyper-parameters and paths are module-level globals defined outside
    this function (``config``, ``batch_size``, ``num_iterations``,
    ``labeled_samples``, ``pretraining``, ``dataset``, ``checkpoint_dir``,
    ``use_teacher``, ``save_teacher``, ``save_best_model``, ``start``, ...).

    Every iteration:
      * draws a labeled batch, alternating Cityscapes (even iterations) and
        GTA5 (odd iterations), using only Cityscapes during the last 100
        iterations, plus a batch of unlabeled Cityscapes images;
      * pseudo-labels the unlabeled images with the EMA teacher (or with the
        student in eval mode when ``use_teacher`` is False);
      * minimizes class-weighted supervised cross entropy, confidence-weighted
        pseudo-label cross entropy and a small entropy term, plus, after
        ``RAMP_UP_ITERS``, contrastive losses against ``feature_memory``;
      * updates the EMA teacher, saves a checkpoint every
        ``save_checkpoint_every`` iterations and evaluates every
        ``val_per_iter`` iterations, reloading ``best_model.pth`` when mIoU
        has not improved for more than ``num_iterations / 5`` iterations.

    After training, the model is evaluated again; then ``best_model.pth`` is
    reloaded, passed through ``update_BN_weak_unlabeled_data`` and evaluated
    once more. The best mIoU seen is printed at the end.
    """
    print(config)
    cudnn.enabled = True
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)
    torch.backends.cudnn.deterministic = True

    # Pick the input normalization matching the pretraining source.
    if pretraining == 'COCO':
        from utils.transformsgpu import normalize_bgr as normalize
    else:
        from utils.transformsgpu import normalize_rgb as normalize

    batch_size_unlabeled = int(batch_size / 2)
    batch_size_labeled = int(batch_size * 1)

    RAMP_UP_ITERS = 2000

    def _next_full_batch(loader_iter, loader, expected_size):
        """Return ``(batch, loader_iter)`` holding the next usable batch.

        A batch whose size differs from ``expected_size`` (the short last
        batch of an epoch) is skipped. When ``loader_iter`` is exhausted, a
        fresh iterator over ``loader`` is created and returned.
        """
        try:
            batch = next(loader_iter)
            if batch[0].shape[0] != expected_size:
                batch = next(loader_iter)
        except StopIteration:  # epoch finished: rebuild the iterator
            loader_iter = iter(loader)
            batch = next(loader_iter)
        return batch, loader_iter

    # DATASETS
    data_loader = get_loader('cityscapes')
    data_path = get_data_path('cityscapes')
    # from 1024x2048 to resize 512x1024 to crop input_size (512x512)
    data_aug = Compose([RandomCrop_city(input_size)])
    train_dataset = data_loader(data_path,
                                is_transform=True,
                                augmentations=data_aug,
                                img_size=input_size,
                                pretraining=pretraining)

    from data.gta5_loader import gtaLoader
    data_path_gta = get_data_path('gta5')
    # from 1024x2048 to resize 512x1024 to crop input_size (512x512)
    data_aug_gta = Compose([RandomCrop_city(input_size)])
    train_dataset_gta = gtaLoader(data_path_gta,
                                  is_transform=True,
                                  augmentations=data_aug_gta,
                                  img_size=input_size,
                                  pretraining=pretraining)

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    partial_size = labeled_samples
    print('Training on number of samples:', partial_size)

    class_weights_curr = ClassBalancing(
        labeled_iters=int(labeled_samples / batch_size_labeled),
        unlabeled_iters=int(
            (train_dataset_size - labeled_samples) / batch_size_unlabeled),
        n_classes=num_classes)

    feature_memory = FeatureMemory(num_samples=labeled_samples,
                                   dataset=dataset,
                                   memory_per_class=256,
                                   feature_size=256,
                                   n_classes=num_classes)

    # Select the labeled / unlabeled partition of Cityscapes.
    if split_id is not None:
        # NOTE: pickle can execute arbitrary code; only load trusted splits.
        with open(split_id, 'rb') as handle:
            train_ids = pickle.load(handle)
        print('loading train ids from {}'.format(split_id))
    else:
        train_ids = np.arange(train_dataset_size)
        np.random.shuffle(train_ids)

    # Labeled Cityscapes: the first `partial_size` ids of the partition.
    train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=batch_size_labeled,
                                  sampler=train_sampler,
                                  num_workers=num_workers,
                                  pin_memory=True)
    trainloader_iter = iter(trainloader)

    # GTA5. SubsetRandomSampler draws a new random order every time the
    # loader is iterated, so rebuilding the iterator reshuffles the data.
    train_ids_gta = np.arange(len(train_dataset_gta))
    np.random.shuffle(train_ids_gta)
    train_sampler_gta = data.sampler.SubsetRandomSampler(train_ids_gta)
    trainloader_gta = data.DataLoader(train_dataset_gta,
                                      batch_size=batch_size_labeled,
                                      sampler=train_sampler_gta,
                                      num_workers=num_workers,
                                      pin_memory=True)
    trainloader_iter_gta = iter(trainloader_gta)

    # Unlabeled Cityscapes: the remaining ids of the partition.
    train_remain_sampler = data.sampler.SubsetRandomSampler(
        train_ids[partial_size:])
    trainloader_remain = data.DataLoader(train_dataset,
                                         batch_size=batch_size_unlabeled,
                                         sampler=train_remain_sampler,
                                         num_workers=num_workers,
                                         pin_memory=True)
    trainloader_remain_iter = iter(trainloader_remain)

    # LOSSES (the pseudo-label loss is rebuilt every iteration with the
    # current class weights, see below)
    supervised_loss = CrossEntropy2d(ignore_label=ignore_label).cuda()

    ''' Deeplab model '''
    if deeplabv2:
        if pretraining == 'COCO':  # coco and imagenet resnet architectures differ a little, just on how to do the stride
            from model.deeplabv2 import Res_Deeplab
        else:  # imagenet pretrained (more modern modification)
            from model.deeplabv2_imagenet import Res_Deeplab
    else:
        from model.deeplabv3 import Res_Deeplab

    model = Res_Deeplab(num_classes=num_classes)

    # load pretrained parameters
    if pretraining == 'COCO':
        saved_state_dict = model_zoo.load_url(
            'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/resnet101COCO-41f33a49.pth'
        )  # COCO pretraining
    else:
        saved_state_dict = model_zoo.load_url(
            'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'
        )  # imagenet pretraining

    # Copy only the pretrained tensors whose name and shape match the model.
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # Optimizer for segmentation network
    learning_rate_object = Learning_Rate_Object(
        config['training']['learning_rate'])
    optimizer = torch.optim.SGD(model.optim_parameters(learning_rate_object),
                                lr=learning_rate,
                                momentum=momentum,
                                weight_decay=weight_decay)

    ema_model = create_ema_model(model, Res_Deeplab)
    ema_model.train()
    ema_model = ema_model.cuda()
    model.train()
    model = model.cuda()
    cudnn.benchmark = True

    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(os.path.join(checkpoint_dir, 'config.json'), 'w') as handle:
        json.dump(config, handle, indent=4, sort_keys=False)
    with open(os.path.join(checkpoint_dir, 'train_split.pkl'), 'wb') as handle:
        pickle.dump(train_ids, handle)

    interp = nn.Upsample(size=(input_size[0], input_size[1]),
                         mode='bilinear',
                         align_corners=True)

    start_iteration = 0
    best_mIoU = 0  # best metric while training
    iters_without_improve = 0

    # TRAINING
    for i_iter in range(start_iteration, num_iterations):
        model.train()  # set mode to training
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        ''' LABELED SAMPLES '''
        # Alternate Cityscapes (even) and GTA5 (odd) batches; the last 100
        # iterations use only Cityscapes data.
        is_cityscapes = i_iter % 2 == 0 or num_iterations - i_iter <= 100
        if is_cityscapes:
            batch, trainloader_iter = _next_full_batch(
                trainloader_iter, trainloader, batch_size_labeled)
        else:
            batch, trainloader_iter_gta = _next_full_batch(
                trainloader_iter_gta, trainloader_gta, batch_size_labeled)

        images, labels, _, _, _ = batch
        images = images.cuda()
        labels = labels.cuda()

        ''' UNLABELED SAMPLES '''
        batch_remain, trainloader_remain_iter = _next_full_batch(
            trainloader_remain_iter, trainloader_remain, batch_size_unlabeled)
        unlabeled_images, _, _, _, _ = batch_remain
        unlabeled_images = unlabeled_images.cuda()

        # Create pseudo-labels: per-pixel argmax class and its probability.
        with torch.no_grad():
            if use_teacher:
                logits_u_w, _ = ema_model(normalize(unlabeled_images, dataset),
                                          return_features=True)
            else:
                model.eval()
                logits_u_w, _ = model(normalize(unlabeled_images, dataset),
                                      return_features=True)
            logits_u_w = interp(logits_u_w).detach()  # prediction unlabeled
            softmax_u_w = torch.softmax(logits_u_w, dim=1)
            max_probs, pseudo_label = torch.max(softmax_u_w, dim=1)

        model.train()

        if is_cityscapes:
            class_weights_curr.add_frequencies(labels.cpu().numpy(),
                                               pseudo_label.cpu().numpy())

        images2, labels2, _, _ = augment_samples(images,
                                                 labels,
                                                 None,
                                                 random.random() < 0.25,
                                                 batch_size_labeled,
                                                 ignore_label,
                                                 weak=True)

        '''
        CROSS ENTROPY FOR UNLABELED USING PSEUDOLABELS
        Once you have the pseudolabel, perform strong augmentation to force the network to yield lower confidence scores for pushing them up
        '''
        # Two strongly augmented views; ClassMix only after the ramp-up.
        do_classmix = i_iter > RAMP_UP_ITERS and random.random() < 0.75
        unlabeled_images_aug1, pseudo_label1, max_probs1, _ = augment_samples(
            unlabeled_images, pseudo_label, max_probs, do_classmix,
            batch_size_unlabeled, ignore_label)

        do_classmix = i_iter > RAMP_UP_ITERS and random.random() < 0.75
        unlabeled_images_aug2, pseudo_label2, max_probs2, _ = augment_samples(
            unlabeled_images, pseudo_label, max_probs, do_classmix,
            batch_size_unlabeled, ignore_label)

        joined_unlabeled = torch.cat(
            (unlabeled_images_aug1, unlabeled_images_aug2), dim=0)
        joined_pseudolabels = torch.cat((pseudo_label1, pseudo_label2), dim=0)
        joined_maxprobs = torch.cat((max_probs1, max_probs2), dim=0)

        pred_joined_unlabeled, features_joined_unlabeled = model(
            normalize(joined_unlabeled, dataset), return_features=True)
        pred_joined_unlabeled = interp(pred_joined_unlabeled)

        joined_labeled = images2
        joined_labels = labels2
        labeled_pred, labeled_features = model(
            normalize(joined_labeled, dataset), return_features=True)
        labeled_pred = interp(labeled_pred)

        class_weights = torch.from_numpy(np.ones((num_classes))).cuda()
        if i_iter > RAMP_UP_ITERS:
            class_weights = torch.from_numpy(
                class_weights_curr.get_weights(num_iterations,
                                               only_labeled=False)).cuda()

        # SUPERVISED SEGMENTATION
        loss = supervised_loss(labeled_pred,
                               joined_labels,
                               weight=class_weights.float())

        # SELF-SUPERVISED SEGMENTATION: cross entropy against pseudo-labels.
        unlabeled_loss = CrossEntropyLoss2dPixelWiseWeighted(
            ignore_index=ignore_label, weight=class_weights.float()).cuda()

        # Pseudo-label weighting: sigmoid ramp-up times confidence ** 6.
        pixelWiseWeight = sigmoid_ramp_up(i_iter, RAMP_UP_ITERS) * torch.ones(
            joined_maxprobs.shape).cuda()
        pixelWiseWeight = pixelWiseWeight * torch.pow(joined_maxprobs.detach(),
                                                      6)

        loss = loss + unlabeled_loss(pred_joined_unlabeled,
                                     joined_pseudolabels, pixelWiseWeight)

        # entropy loss
        valid_mask = (joined_pseudolabels != ignore_label).unsqueeze(1)
        loss = loss + entropy_loss(
            torch.nn.functional.softmax(pred_joined_unlabeled, dim=1),
            valid_mask) * 0.01

        # CONTRASTIVE LEARNING: fill the memory bank with labeled Cityscapes
        # features that are predicted correctly with probability > 0.95.
        # Starts 1000 iterations before the contrastive losses are enabled.
        if is_cityscapes and i_iter > RAMP_UP_ITERS - 1000:
            with torch.no_grad():
                if use_teacher:
                    labeled_pred_ema, labeled_features_ema = ema_model(
                        normalize(joined_labeled, dataset),
                        return_features=True)
                else:
                    model.eval()
                    labeled_pred_ema, labeled_features_ema = model(
                        normalize(joined_labeled, dataset),
                        return_features=True)
                    model.train()

                labeled_pred_ema = interp(labeled_pred_ema)
                probability_prediction_ema, label_prediction_ema = torch.max(
                    torch.softmax(labeled_pred_ema, dim=1), dim=1)

            feat_size = (labeled_features_ema.shape[2],
                         labeled_features_ema.shape[3])
            labels_down = nn.functional.interpolate(
                joined_labels.float().unsqueeze(1),
                size=feat_size,
                mode='nearest').squeeze(1)
            label_prediction_down = nn.functional.interpolate(
                label_prediction_ema.float().unsqueeze(1),
                size=feat_size,
                mode='nearest').squeeze(1)
            probability_prediction_down = nn.functional.interpolate(
                probability_prediction_ema.float().unsqueeze(1),
                size=feat_size,
                mode='nearest').squeeze(1)

            # mask where the labeled predictions are correct and confident
            mask_prediction_correctly = (
                (label_prediction_down == labels_down).float() *
                (probability_prediction_down > 0.95).float()).bool()

            labeled_features_correct = labeled_features_ema.permute(0, 2, 3, 1)
            labels_down_correct = labels_down[mask_prediction_correctly]
            labeled_features_correct = labeled_features_correct[
                mask_prediction_correctly, ...]

            # get projected features
            with torch.no_grad():
                if use_teacher:
                    proj_labeled_features_correct = ema_model.projection_head(
                        labeled_features_correct)
                else:
                    model.eval()
                    proj_labeled_features_correct = model.projection_head(
                        labeled_features_correct)
                    model.train()

            # update memory bank
            feature_memory.add_features_from_sample_learned(
                ema_model, proj_labeled_features_correct, labels_down_correct,
                batch_size_labeled)

        if i_iter > RAMP_UP_ITERS:
            '''
            LABELED TO LABELED. Force features from labeled samples to be similar to other features from the same class (which also leads to good predictions)
            '''
            # Downsample the labels of the *current* labeled batch to the
            # feature resolution. This is computed here because the memory
            # bank step above only runs on Cityscapes iterations.
            labels_down = nn.functional.interpolate(
                joined_labels.float().unsqueeze(1),
                size=(labeled_features.shape[2], labeled_features.shape[3]),
                mode='nearest').squeeze(1)

            # Use every non-ignored pixel (not only the correct ones).
            mask_labeled = labels_down != ignore_label

            labeled_features_all = labeled_features.permute(0, 2, 3, 1)
            labels_down_all = labels_down[mask_labeled]
            labeled_features_all = labeled_features_all[mask_labeled, ...]

            # get prediction features
            proj_labeled_features_all = model.projection_head(
                labeled_features_all)
            pred_labeled_features_all = model.prediction_head(
                proj_labeled_features_all)

            loss_contr_labeled = contrastive_class_to_class_learned_memory(
                model, pred_labeled_features_all, labels_down_all, num_classes,
                feature_memory.memory)

            loss = loss + loss_contr_labeled * 0.2
            '''
            CONTRASTIVE LEARNING ON UNLABELED DATA. align unlabeled features to labeled features
            '''
            joined_pseudolabels_down = nn.functional.interpolate(
                joined_pseudolabels.float().unsqueeze(1),
                size=(features_joined_unlabeled.shape[2],
                      features_joined_unlabeled.shape[3]),
                mode='nearest').squeeze(1)

            # drop pixels whose pseudo-label is ignore_label (e.g. padding
            # introduced by the augmentations)
            mask = joined_pseudolabels_down != ignore_label

            features_joined_unlabeled = features_joined_unlabeled.permute(
                0, 2, 3, 1)
            features_joined_unlabeled = features_joined_unlabeled[mask, ...]
            joined_pseudolabels_down = joined_pseudolabels_down[mask]

            # get projected features
            proj_feat_unlabeled = model.projection_head(
                features_joined_unlabeled)
            pred_feat_unlabeled = model.prediction_head(proj_feat_unlabeled)

            loss_contr_unlabeled = contrastive_class_to_class_learned_memory(
                model, pred_feat_unlabeled, joined_pseudolabels_down,
                num_classes, feature_memory.memory)

            loss = loss + loss_contr_unlabeled * 0.2

        # optimize
        loss.backward()
        optimizer.step()

        # EMA teacher momentum: cosine schedule from 0.995 (i_iter = 0)
        # towards 1.0 (i_iter = num_iterations).
        m = 1 - (1 - 0.995) * (math.cos(math.pi * i_iter / num_iterations) +
                               1) / 2
        ema_model = update_ema_variables(ema_model=ema_model,
                                         model=model,
                                         alpha_teacher=m,
                                         iteration=i_iter)

        if i_iter % save_checkpoint_every == 0 and i_iter != 0:
            _save_checkpoint(i_iter, model, optimizer, config)

        if i_iter % val_per_iter == 0 and i_iter != 0:
            print('iter = {0:6d}/{1:6d}'.format(i_iter, num_iterations))

            model.eval()
            mIoU, _ = evaluate(model,
                               dataset,
                               deeplabv2=deeplabv2,
                               ignore_label=ignore_label,
                               save_dir=checkpoint_dir,
                               pretraining=pretraining)
            model.train()

            if mIoU > best_mIoU:
                best_mIoU = mIoU
                best_to_save = ema_model if save_teacher else model
                _save_checkpoint(i_iter,
                                 best_to_save,
                                 optimizer,
                                 config,
                                 save_best=True)
                iters_without_improve = 0
            else:
                iters_without_improve += val_per_iter
            '''
            if the performance has not improved in N iterations, reload the best model to optimize again with a lower LR
            Simulating an iterative training'''
            if iters_without_improve > num_iterations / 5.:
                print('Re-loading a previous best model')
                checkpoint = torch.load(
                    os.path.join(checkpoint_dir, 'best_model.pth'))
                model.load_state_dict(checkpoint['model'])
                ema_model = create_ema_model(model, Res_Deeplab)
                ema_model.train()
                ema_model = ema_model.cuda()
                model.train()
                model = model.cuda()
                iters_without_improve = 0  # reset timer

    _save_checkpoint(num_iterations, model, optimizer, config)

    # FINISH TRAINING, evaluate again
    model.eval()
    mIoU, _ = evaluate(model,
                       dataset,
                       deeplabv2=deeplabv2,
                       ignore_label=ignore_label,
                       save_dir=checkpoint_dir,
                       pretraining=pretraining)
    model.train()

    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter, model, optimizer, config, save_best=True)

    # TRY IMPROVING BEST MODEL BY UPDATING BN STATS ON UNLABELED DATA
    checkpoint = torch.load(os.path.join(checkpoint_dir, 'best_model.pth'))
    model.load_state_dict(checkpoint['model'])
    model = model.cuda()

    model = update_BN_weak_unlabeled_data(model, normalize,
                                          batch_size_unlabeled,
                                          trainloader_remain)
    model.eval()
    mIoU, _ = evaluate(model,
                       dataset,
                       deeplabv2=deeplabv2,
                       ignore_label=ignore_label,
                       save_dir=checkpoint_dir,
                       pretraining=pretraining)
    model.train()
    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter, model, optimizer, config, save_best=True)

    print('BEST MIOU')
    print(best_mIoU)

    end = timeit.default_timer()
    print('Total time: ' + str(end - start) + ' seconds')