Code example #1
# NOTE: these examples are excerpts from a larger training script and rely on
# module-level globals (e.g. `config`, `dataset`, `batch_size`, and helper
# modules such as `transformmasks` and `palette`). The core imports used
# across the examples are:
import os
import json
import pickle
import random
import timeit

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.model_zoo as model_zoo
from torch.backends import cudnn
from torch.utils import data, tensorboard

def augment_samples(images,
                    labels,
                    probs,
                    do_classmix,
                    batch_size,
                    ignore_label,
                    weak=False):
    """
    Perform data augmentation

    Args:
        images: BxCxWxH images to augment
        labels:  BxWxH labels to augment
        probs:  BxWxH probability maps to augment
        do_classmix: whether to apply classmix augmentation
        batch_size: batch size
        ignore_label: ignore class value
        weak: whether to perform weak or strong augmentation

    Returns:
        augmented images, augmented labels, augmented probs, and the params
        dict describing the augmentations that were applied

    """

    if do_classmix:
        # ClassMix: Get mask for image A
        for image_i in range(batch_size):  # for each image
            classes = torch.unique(
                labels[image_i])  # get unique classes in pseudolabel A
            nclasses = classes.shape[0]

            # remove ignore class
            if ignore_label in classes and nclasses > 1:
                classes = classes[classes != ignore_label]
                nclasses = nclasses - 1

            if dataset == 'pascal_voc':  # for VOC, remove class 0 (background)
                if 0 in classes and nclasses > 1:
                    classes = classes[classes != 0]
                    nclasses = nclasses - 1

            # pick half of the classes randomly
            classes = (classes[torch.Tensor(
                np.random.choice(nclasses,
                                 int(((nclasses - nclasses % 2) / 2) + 1),
                                 replace=False)).long()]).cuda()

            # accumulate masks
            if image_i == 0:
                MixMask = transformmasks.generate_class_mask(
                    labels[image_i], classes).unsqueeze(0).cuda()
            else:
                MixMask = torch.cat((MixMask,
                                     transformmasks.generate_class_mask(
                                         labels[image_i],
                                         classes).unsqueeze(0).cuda()))

        params = {"Mix": MixMask}
    else:
        params = {}

    if weak:
        params["flip"] = random.random() < 0.5
        params["ColorJitter"] = random.random() < 0.2
        params["GaussianBlur"] = random.random() < 0.
        params["Grayscale"] = random.random() < 0.0
        params["Solarize"] = random.random() < 0.0
        if random.random() < 0.5:
            scale = random.uniform(0.75, 1.75)
        else:
            scale = 1
        params["RandomScaleCrop"] = scale

        # Apply weak augmentations to unlabeled images
        image_aug, labels_aug, probs_aug = augmentationTransform(
            params,
            data=images,
            target=labels,
            probs=probs,
            jitter_vale=0.125,
            min_sigma=0.1,
            max_sigma=1.5,
            ignore_label=ignore_label)
    else:
        params["flip"] = random.random() < 0.5
        params["ColorJitter"] = random.random() < 0.8
        params["GaussianBlur"] = random.random() < 0.2
        params["Grayscale"] = random.random() < 0.0
        params["Solarize"] = random.random() < 0.0
        if random.random() < 0.80:
            scale = random.uniform(0.75, 1.75)
        else:
            scale = 1
        params["RandomScaleCrop"] = scale

        # Apply strong augmentations to unlabeled images
        image_aug, labels_aug, probs_aug = augmentationTransform(
            params,
            data=images,
            target=labels,
            probs=probs,
            jitter_vale=0.25,
            min_sigma=0.1,
            max_sigma=1.5,
            ignore_label=ignore_label)

    return image_aug, labels_aug, probs_aug, params
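
For reference, a minimal usage sketch of augment_samples (hypothetical shapes and values; assumes a CUDA device and that module-level globals such as `dataset` are configured):

# Hypothetical call with the shapes from the docstring
# (images BxCxWxH, labels and probs BxWxH); values are illustrative only.
images = torch.randn(2, 3, 256, 256).cuda()
labels = torch.randint(0, 19, (2, 256, 256)).cuda()
probs = torch.rand(2, 256, 256).cuda()

image_aug, labels_aug, probs_aug, params = augment_samples(
    images, labels, probs,
    do_classmix=True,   # build a ClassMix mask from half the classes
    batch_size=2,
    ignore_label=250,
    weak=False)         # take the strong-augmentation branch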
Code example #2
def main():
    torch.cuda.empty_cache()
    print(config)

    best_mIoU = 0

    if consistency_loss == 'CE':
        if len(gpus) > 1:
            unlabeled_loss = torch.nn.DataParallel(
                CrossEntropyLoss2dPixelWiseWeighted(ignore_index=ignore_label),
                device_ids=gpus).cuda()
        else:
            unlabeled_loss = CrossEntropyLoss2dPixelWiseWeighted(
                ignore_index=ignore_label).cuda()
    elif consistency_loss == 'MSE':
        if len(gpus) > 1:
            unlabeled_loss = torch.nn.DataParallel(MSELoss2d(),
                                                   device_ids=gpus).cuda()
        else:
            unlabeled_loss = MSELoss2d().cuda()

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=num_classes)

    # load pretrained parameters
    if restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(restore_from)
    else:
        saved_state_dict = torch.load(restore_from)

    # Copy loaded parameters to model
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # Initialize EMA model
    if train_unlabeled:
        ema_model = create_ema_model(model)
        ema_model.train()
        ema_model = ema_model.cuda()
    else:
        ema_model = None

    if len(gpus) > 1:
        if use_sync_batchnorm:
            model = convert_model(model)
            model = DataParallelWithCallback(model, device_ids=gpus)
        else:
            model = torch.nn.DataParallel(model, device_ids=gpus)
    model.train()
    model.cuda()

    cudnn.benchmark = True

    if dataset == 'pascal_voc':
        data_loader = get_loader(dataset)
        data_path = get_data_path(dataset)
        train_dataset = data_loader(data_path,
                                    crop_size=input_size,
                                    scale=random_scale,
                                    mirror=random_flip)

    elif dataset == 'cityscapes':
        data_loader = get_loader('cityscapes')
        data_path = get_data_path('cityscapes')
        if random_crop:
            data_aug = Compose([RandomCrop_city(input_size)])
        else:
            data_aug = None

        train_dataset = data_loader(data_path,
                                    is_transform=True,
                                    augmentations=data_aug,
                                    img_size=input_size)

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    partial_size = labeled_samples
    print('Training on number of samples:', partial_size)
    if split_id is not None:
        train_ids = pickle.load(open(split_id, 'rb'))
        print('loading train ids from {}'.format(split_id))
    else:
        np.random.seed(random_seed)
        train_ids = np.arange(train_dataset_size)
        np.random.shuffle(train_ids)

    train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  sampler=train_sampler,
                                  num_workers=num_workers,
                                  pin_memory=True)
    trainloader_iter = iter(trainloader)

    if train_unlabeled:
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=1,
                                             pin_memory=True)
        trainloader_remain_iter = iter(trainloader_remain)

    # Optimizer for segmentation network
    learning_rate_object = Learning_Rate_Object(
        config['training']['learning_rate'])

    if optimizer_type == 'SGD':
        if len(gpus) > 1:
            optimizer = optim.SGD(
                model.module.optim_parameters(learning_rate_object),
                lr=learning_rate,
                momentum=momentum,
                weight_decay=weight_decay)
        else:
            optimizer = optim.SGD(
                model.optim_parameters(learning_rate_object),
                lr=learning_rate,
                momentum=momentum,
                weight_decay=weight_decay)

    optimizer.zero_grad()

    interp = nn.Upsample(size=(input_size[0], input_size[1]),
                         mode='bilinear',
                         align_corners=True)

    start_iteration = 0

    if args.resume:
        start_iteration, model, optimizer, ema_model = _resume_checkpoint(
            args.resume, model, optimizer, ema_model)

    accumulated_loss_l = []
    if train_unlabeled:
        accumulated_loss_u = []

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    with open(checkpoint_dir + '/config.json', 'w') as handle:
        json.dump(config, handle, indent=4, sort_keys=False)
    pickle.dump(train_ids,
                open(os.path.join(checkpoint_dir, 'train_split.pkl'), 'wb'))

    epochs_since_start = 0
    for i_iter in range(start_iteration, num_iterations):
        model.train()

        loss_l_value = 0
        if train_unlabeled:
            loss_u_value = 0

        optimizer.zero_grad()

        if lr_schedule:
            adjust_learning_rate(optimizer, i_iter)

        # Training loss for labeled data only
        try:
            batch = next(trainloader_iter)
            if batch[0].shape[0] != batch_size:
                batch = next(trainloader_iter)
        except StopIteration:
            epochs_since_start = epochs_since_start + 1
            print('Epochs since start: ', epochs_since_start)
            trainloader_iter = iter(trainloader)
            batch = next(trainloader_iter)

        weak_parameters = {"flip": 0}

        images, labels, _, _, _ = batch
        images = images.cuda()
        labels = labels.cuda()

        images, labels = weakTransform(weak_parameters,
                                       data=images,
                                       target=labels)
        intermediary_var = model(images)
        pred = interp(intermediary_var)

        L_l = loss_calc(pred, labels)

        if train_unlabeled:
            try:
                batch_remain = next(trainloader_remain_iter)
                if batch_remain[0].shape[0] != batch_size:
                    batch_remain = next(trainloader_remain_iter)
            except StopIteration:
                trainloader_remain_iter = iter(trainloader_remain)
                batch_remain = next(trainloader_remain_iter)

            images_remain, _, _, _, _ = batch_remain
            images_remain = images_remain.cuda()
            inputs_u_w, _ = weakTransform(weak_parameters, data=images_remain)
            logits_u_w = interp(ema_model(inputs_u_w))
            logits_u_w, _ = weakTransform(
                getWeakInverseTransformParameters(weak_parameters),
                data=logits_u_w.detach())

            softmax_u_w = torch.softmax(logits_u_w.detach(), dim=1)
            max_probs, argmax_u_w = torch.max(softmax_u_w, dim=1)

            if mix_mask == "class":

                for image_i in range(batch_size):
                    classes = torch.unique(argmax_u_w[image_i])
                    classes = classes[classes != ignore_label]
                    nclasses = classes.shape[0]
                    classes = (classes[torch.Tensor(
                        np.random.choice(nclasses,
                                         int((nclasses - nclasses % 2) / 2),
                                         replace=False)).long()]).cuda()
                    if image_i == 0:
                        MixMask = transformmasks.generate_class_mask(
                            argmax_u_w[image_i], classes).unsqueeze(0).cuda()
                    else:
                        MixMask = torch.cat(
                            (MixMask,
                             transformmasks.generate_class_mask(
                                 argmax_u_w[image_i],
                                 classes).unsqueeze(0).cuda()))

            elif mix_mask == 'cut':
                img_size = inputs_u_w.shape[2:4]
                for image_i in range(batch_size):
                    if image_i == 0:
                        MixMask = torch.from_numpy(
                            transformmasks.generate_cutout_mask(
                                img_size)).unsqueeze(0).cuda().float()
                    else:
                        MixMask = torch.cat(
                            (MixMask,
                             torch.from_numpy(
                                 transformmasks.generate_cutout_mask(
                                     img_size)).unsqueeze(0).cuda().float()))

            elif mix_mask == "cow":
                img_size = inputs_u_w.shape[2:4]
                sigma_min = 8
                sigma_max = 32
                p_min = 0.5
                p_max = 0.5
                for image_i in range(batch_size):
                    sigma = np.exp(
                        np.random.uniform(np.log(sigma_min),
                                          np.log(sigma_max)))  # Random sigma
                    p = np.random.uniform(p_min, p_max)  # Random p
                    if image_i == 0:
                        MixMask = torch.from_numpy(
                            transformmasks.generate_cow_mask(
                                img_size, sigma, p,
                                seed=None)).unsqueeze(0).cuda().float()
                    else:
                        MixMask = torch.cat(
                            (MixMask,
                             torch.from_numpy(
                                 transformmasks.generate_cow_mask(
                                     img_size, sigma, p,
                                     seed=None)).unsqueeze(0).cuda().float()))

            elif mix_mask is None:
                MixMask = torch.ones(inputs_u_w.shape).cuda()

            strong_parameters = {"Mix": MixMask}
            if random_flip:
                strong_parameters["flip"] = random.randint(0, 1)
            else:
                strong_parameters["flip"] = 0
            if color_jitter:
                strong_parameters["ColorJitter"] = random.uniform(0, 1)
            else:
                strong_parameters["ColorJitter"] = 0
            if gaussian_blur:
                strong_parameters["GaussianBlur"] = random.uniform(0, 1)
            else:
                strong_parameters["GaussianBlur"] = 0

            inputs_u_s, _ = strongTransform(strong_parameters,
                                            data=images_remain)
            logits_u_s = interp(model(inputs_u_s))

            softmax_u_w_mixed, _ = strongTransform(strong_parameters,
                                                   data=softmax_u_w)
            max_probs, pseudo_label = torch.max(softmax_u_w_mixed, dim=1)

            if pixel_weight == "threshold_uniform":
                unlabeled_weight = torch.sum(
                    max_probs.ge(0.968).long() == 1).item() / np.size(
                        np.array(pseudo_label.cpu()))
                pixelWiseWeight = unlabeled_weight * torch.ones(
                    max_probs.shape).cuda()
            elif pixel_weight == "threshold":
                pixelWiseWeight = max_probs.ge(0.968).long().cuda()
            elif pixel_weight == 'sigmoid':
                max_iter = 10000
                pixelWiseWeight = sigmoid_ramp_up(
                    i_iter, max_iter) * torch.ones(max_probs.shape).cuda()
            elif pixel_weight is False:
                pixelWiseWeight = torch.ones(max_probs.shape).cuda()

            if consistency_loss == 'CE':
                L_u = consistency_weight * unlabeled_loss(
                    logits_u_s, pseudo_label, pixelWiseWeight)
            elif consistency_loss == 'MSE':
                unlabeled_weight = torch.sum(
                    max_probs.ge(0.968).long() == 1).item() / np.size(
                        np.array(pseudo_label.cpu()))
                #softmax_u_w_mixed = torch.cat((softmax_u_w_mixed[1].unsqueeze(0),softmax_u_w_mixed[0].unsqueeze(0)))
                L_u = consistency_weight * unlabeled_weight * unlabeled_loss(
                    logits_u_s, softmax_u_w_mixed)

            loss = L_l + L_u

        else:
            loss = L_l

        if len(gpus) > 1:
            loss = loss.mean()
            loss_l_value += L_l.mean().item()
            if train_unlabeled:
                loss_u_value += L_u.mean().item()
        else:
            loss_l_value += L_l.item()
            if train_unlabeled:
                loss_u_value += L_u.item()

        loss.backward()
        optimizer.step()

        # update Mean teacher network
        if ema_model is not None:
            alpha_teacher = 0.99
            ema_model = update_ema_variables(ema_model=ema_model,
                                             model=model,
                                             alpha_teacher=alpha_teacher,
                                             iteration=i_iter)

        if train_unlabeled:
            print('iter = {0:6d}/{1:6d}, loss_l = {2:.3f}, loss_u = {3:.3f}'.
                  format(i_iter, num_iterations, loss_l_value, loss_u_value))
        else:
            print('iter = {0:6d}/{1:6d}, loss_l = {2:.3f}'.format(
                i_iter, num_iterations, loss_l_value))

        if i_iter % save_checkpoint_every == 0 and i_iter != 0:
            _save_checkpoint(i_iter, model, optimizer, config, ema_model)

        if use_tensorboard:
            if 'tensorboard_writer' not in locals():
                tensorboard_writer = tensorboard.SummaryWriter(log_dir,
                                                               flush_secs=30)

            accumulated_loss_l.append(loss_l_value)
            if train_unlabeled:
                accumulated_loss_u.append(loss_u_value)
            if i_iter % log_per_iter == 0 and i_iter != 0:

                tensorboard_writer.add_scalar('Training/Supervised loss',
                                              np.mean(accumulated_loss_l),
                                              i_iter)
                accumulated_loss_l = []

                if train_unlabeled:
                    tensorboard_writer.add_scalar('Training/Unsupervised loss',
                                                  np.mean(accumulated_loss_u),
                                                  i_iter)
                    accumulated_loss_u = []

        if i_iter % val_per_iter == 0 and i_iter != 0:
            model.eval()
            mIoU, eval_loss = evaluate(model,
                                       dataset,
                                       ignore_label=ignore_label,
                                       input_size=(512, 1024),
                                       save_dir=checkpoint_dir)

            model.train()

            if mIoU > best_mIoU and save_best_model:
                best_mIoU = mIoU
                _save_checkpoint(i_iter,
                                 model,
                                 optimizer,
                                 config,
                                 ema_model,
                                 save_best=True)

            if use_tensorboard:
                tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
                tensorboard_writer.add_scalar('Validation/Loss', eval_loss,
                                              i_iter)

        if save_unlabeled_images and train_unlabeled and i_iter % save_checkpoint_every == 0:
            # Saves two mixed images and the corresponding prediction
            save_image(inputs_u_s[0].cpu(), i_iter, 'input1',
                       palette.CityScpates_palette)
            save_image(inputs_u_s[1].cpu(), i_iter, 'input2',
                       palette.CityScpates_palette)
            _, pred_u_s = torch.max(logits_u_s, dim=1)
            save_image(pred_u_s[0].cpu(), i_iter, 'pred1',
                       palette.CityScpates_palette)
            save_image(pred_u_s[1].cpu(), i_iter, 'pred2',
                       palette.CityScpates_palette)

    _save_checkpoint(num_iterations, model, optimizer, config, ema_model)

    model.eval()
    mIoU, val_loss = evaluate(model,
                              dataset,
                              ignore_label=ignore_label,
                              input_size=(512, 1024),
                              save_dir=checkpoint_dir)

    model.train()
    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter,
                         model,
                         optimizer,
                         config,
                         ema_model,
                         save_best=True)

    if use_tensorboard:
        tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
        tensorboard_writer.add_scalar('Validation/Loss', val_loss, i_iter)

    end = timeit.default_timer()
    print('Total time: ' + str(end - start) + ' seconds')
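
The training loop updates the teacher via update_ema_variables, which is defined elsewhere in the repository. Two helpers used above, sketched under the assumption that they follow the standard mean-teacher recipe (not the repository's exact code):

# Assumed behaviour of update_ema_variables: exponential moving average of
# the student weights, with the decay ramped up early in training.
def ema_update_sketch(ema_model, model, alpha_teacher, iteration):
    alpha = min(1 - 1 / (iteration + 1), alpha_teacher)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        # teacher <- alpha * teacher + (1 - alpha) * student
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
    return ema_model

# Assumed behaviour of sigmoid_ramp_up (the 'sigmoid' pixel_weight option):
# the standard consistency ramp-up exp(-5 * (1 - t)^2) with t in [0, 1].
def sigmoid_ramp_up_sketch(iteration, max_iter):
    if iteration >= max_iter:
        return 1.0
    t = iteration / max_iter
    return float(np.exp(-5.0 * (1.0 - t) ** 2))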
Code example #3
def main():
    print(config)

    best_mIoU = 0

    if consistency_loss == 'MSE':
        if len(gpus) > 1:
            unlabeled_loss = torch.nn.DataParallel(MSELoss2d(),
                                                   device_ids=gpus).cuda()
        else:
            unlabeled_loss = MSELoss2d().cuda()
    elif consistency_loss == 'CE':
        if len(gpus) > 1:
            unlabeled_loss = torch.nn.DataParallel(
                CrossEntropyLoss2dPixelWiseWeighted(ignore_index=ignore_label),
                device_ids=gpus).cuda()
        else:
            unlabeled_loss = CrossEntropyLoss2dPixelWiseWeighted(
                ignore_index=ignore_label).cuda()

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=num_classes)

    # load pretrained parameters
    if restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(restore_from)
    else:
        saved_state_dict = torch.load(restore_from)

    # Copy loaded parameters to model
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # init ema-model
    if train_unlabeled:
        ema_model = create_ema_model(model)
        ema_model.train()
        ema_model = ema_model.cuda()
    else:
        ema_model = None

    if len(gpus) > 1:
        if use_sync_batchnorm:
            model = convert_model(model)
            model = DataParallelWithCallback(model, device_ids=gpus)
        else:
            model = torch.nn.DataParallel(model, device_ids=gpus)
    model.train()
    model.cuda()

    cudnn.benchmark = True
    data_loader = get_loader(config['dataset'])
    data_path = get_data_path(config['dataset'])
    data_aug = Compose([RandomHorizontallyFlip()])
    if dataset == 'cityscapes':
        train_dataset = data_loader(data_path,
                                    is_transform=True,
                                    augmentations=data_aug,
                                    img_size=input_size,
                                    img_mean=IMG_MEAN)
    elif dataset == 'multiview':
        # adaption data
        data_path = '/tmp/tcn_data/texture_multibot_push_left10050/videos/train_adaptation'
        train_dataset = data_loader(data_path,
                                    is_transform=True,
                                    view_idx=0,
                                    number_views=1,
                                    load_seg_mask=False,
                                    augmentations=data_aug,
                                    img_size=input_size,
                                    img_mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    if labeled_samples is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=num_workers,
                                      pin_memory=True)

        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=num_workers,
                                             pin_memory=True)
        trainloader_remain_iter = iter(trainloader_remain)

    else:
        partial_size = labeled_samples
        print('Training on number of samples:', partial_size)
        np.random.seed(random_seed)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=4,
                                             pin_memory=True)

        trainloader_remain_iter = iter(trainloader_remain)

    # New loader for domain transfer (supervised data)
    data_path = '/tmp/tcn_data/texture_multibot_push_left10050/videos/train_adaptation'
    data_aug = Compose([RandomHorizontallyFlip()])
    if dataset == 'multiview':
        train_dataset = data_loader(data_path,
                                    is_transform=True,
                                    view_idx=0,
                                    number_views=1,
                                    load_seg_mask=True,
                                    augmentations=data_aug,
                                    img_size=input_size,
                                    img_mean=IMG_MEAN)
    else:
        data_loader = get_loader('gta')
        data_path = get_data_path('gta')
        train_dataset = data_loader(data_path,
                                    list_path='./data/gta5_list/train.txt',
                                    augmentations=data_aug,
                                    img_size=(1280, 720),
                                    mean=IMG_MEAN)

    trainloader = data.DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  pin_memory=True)

    # training loss for labeled data only
    trainloader_iter = iter(trainloader)
    print('gta size:', len(trainloader))

    # optimizer for segmentation network
    learning_rate_object = Learning_Rate_Object(
        config['training']['learning_rate'])

    if optimizer_type == 'SGD':
        if len(gpus) > 1:
            optimizer = optim.SGD(
                model.module.optim_parameters(learning_rate_object),
                lr=learning_rate,
                momentum=momentum,
                weight_decay=weight_decay)
        else:
            optimizer = optim.SGD(model.optim_parameters(learning_rate_object),
                                  lr=learning_rate,
                                  momentum=momentum,
                                  weight_decay=weight_decay)
    elif optimizer_type == 'Adam':
        if len(gpus) > 1:
            optimizer = optim.Adam(
                model.module.optim_parameters(learning_rate_object),
                lr=learning_rate,
                weight_decay=weight_decay)  # Adam takes no momentum argument
        else:
            optimizer = optim.Adam(
                model.optim_parameters(learning_rate_object),
                lr=learning_rate,
                weight_decay=weight_decay)

    optimizer.zero_grad()

    interp = nn.Upsample(size=(input_size[0], input_size[1]),
                         mode='bilinear',
                         align_corners=True)
    start_iteration = 0

    if args.resume:
        start_iteration, model, optimizer, ema_model = _resume_checkpoint(
            args.resume, model, optimizer, ema_model)

    accumulated_loss_l = []
    accumulated_loss_u = []

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    with open(checkpoint_dir + '/config.json', 'w') as handle:
        json.dump(config, handle, indent=4, sort_keys=True)

    epochs_since_start = 0
    for i_iter in range(start_iteration, num_iterations):
        model.train()

        loss_u_value = 0
        loss_l_value = 0

        optimizer.zero_grad()

        if lr_schedule:
            adjust_learning_rate(optimizer, i_iter)

        # training loss for labeled data only
        try:
            batch = next(trainloader_iter)
            if batch[0].shape[0] != batch_size:
                batch = next(trainloader_iter)
        except StopIteration:
            epochs_since_start = epochs_since_start + 1
            print('Epochs since start: ', epochs_since_start)
            trainloader_iter = iter(trainloader)
            batch = next(trainloader_iter)

        #if random_flip:
        #    weak_parameters={"flip":random.randint(0,1)}
        #else:
        weak_parameters = {"flip": 0}

        images, labels, _, _ = batch
        images = images.cuda()
        labels = labels.cuda().long()

        #images, labels = weakTransform(weak_parameters, data = images, target = labels)

        pred = interp(model(images))
        L_l = loss_calc(pred, labels)  # Cross entropy loss for labeled data
        #L_l = torch.Tensor([0.0]).cuda()

        if train_unlabeled:
            try:
                batch_remain = next(trainloader_remain_iter)
                if batch_remain[0].shape[0] != batch_size:
                    batch_remain = next(trainloader_remain_iter)
            except StopIteration:
                trainloader_remain_iter = iter(trainloader_remain)
                batch_remain = next(trainloader_remain_iter)

            images_remain, *_ = batch_remain
            images_remain = images_remain.cuda()
            inputs_u_w, _ = weakTransform(weak_parameters, data=images_remain)
            #inputs_u_w = inputs_u_w.clone()
            logits_u_w = interp(ema_model(inputs_u_w))
            logits_u_w, _ = weakTransform(
                getWeakInverseTransformParameters(weak_parameters),
                data=logits_u_w.detach())

            pseudo_label = torch.softmax(logits_u_w.detach(), dim=1)
            max_probs, targets_u_w = torch.max(pseudo_label, dim=1)

            if mix_mask == "class":
                for image_i in range(batch_size):
                    classes = torch.unique(labels[image_i])
                    #classes=classes[classes!=ignore_label]
                    nclasses = classes.shape[0]
                    #if nclasses > 0:
                    classes = (classes[torch.Tensor(
                        np.random.choice(nclasses,
                                         int((nclasses + nclasses % 2) / 2),
                                         replace=False)).long()]).cuda()

                    if image_i == 0:  # assumes batch_size == 2: one mask per image
                        MixMask0 = transformmasks.generate_class_mask(
                            labels[image_i], classes).unsqueeze(0).cuda()
                    else:
                        MixMask1 = transformmasks.generate_class_mask(
                            labels[image_i], classes).unsqueeze(0).cuda()

            elif mix_mask is None:
                # define both masks used below; all-ones means no mixing
                MixMask0 = MixMask1 = torch.ones(inputs_u_w.shape).cuda()

            strong_parameters = {"Mix": MixMask0}
            if random_flip:
                strong_parameters["flip"] = random.randint(0, 1)
            else:
                strong_parameters["flip"] = 0
            if color_jitter:
                strong_parameters["ColorJitter"] = random.uniform(0, 1)
            else:
                strong_parameters["ColorJitter"] = 0
            if gaussian_blur:
                strong_parameters["GaussianBlur"] = random.uniform(0, 1)
            else:
                strong_parameters["GaussianBlur"] = 0

            inputs_u_s0, _ = strongTransform(
                strong_parameters,
                data=torch.cat(
                    (images[0].unsqueeze(0), images_remain[0].unsqueeze(0))))
            strong_parameters["Mix"] = MixMask1
            inputs_u_s1, _ = strongTransform(
                strong_parameters,
                data=torch.cat(
                    (images[1].unsqueeze(0), images_remain[1].unsqueeze(0))))
            inputs_u_s = torch.cat((inputs_u_s0, inputs_u_s1))
            logits_u_s = interp(model(inputs_u_s))

            strong_parameters["Mix"] = MixMask0
            _, targets_u0 = strongTransform(strong_parameters,
                                            target=torch.cat(
                                                (labels[0].unsqueeze(0),
                                                 targets_u_w[0].unsqueeze(0))))
            strong_parameters["Mix"] = MixMask1
            _, targets_u1 = strongTransform(strong_parameters,
                                            target=torch.cat(
                                                (labels[1].unsqueeze(0),
                                                 targets_u_w[1].unsqueeze(0))))
            targets_u = torch.cat((targets_u0, targets_u1)).long()

            if pixel_weight == "threshold_uniform":
                unlabeled_weight = torch.sum(
                    max_probs.ge(0.968).long() == 1).item() / np.size(
                        np.array(targets_u.cpu()))
                pixelWiseWeight = unlabeled_weight * torch.ones(
                    max_probs.shape).cuda()
            elif pixel_weight == "threshold":
                pixelWiseWeight = max_probs.ge(0.968).float().cuda()
            elif pixel_weight is False:
                pixelWiseWeight = torch.ones(max_probs.shape).cuda()

            onesWeights = torch.ones((pixelWiseWeight.shape)).cuda()
            strong_parameters["Mix"] = MixMask0
            _, pixelWiseWeight0 = strongTransform(
                strong_parameters,
                target=torch.cat((onesWeights[0].unsqueeze(0),
                                  pixelWiseWeight[0].unsqueeze(0))))
            strong_parameters["Mix"] = MixMask1
            _, pixelWiseWeight1 = strongTransform(
                strong_parameters,
                target=torch.cat((onesWeights[1].unsqueeze(0),
                                  pixelWiseWeight[1].unsqueeze(0))))
            pixelWiseWeight = torch.cat(
                (pixelWiseWeight0, pixelWiseWeight1)).cuda()

            if consistency_loss == 'MSE':
                unlabeled_weight = torch.sum(
                    max_probs.ge(0.968).long() == 1).item() / np.size(
                        np.array(targets_u.cpu()))
                #pseudo_label = torch.cat((pseudo_label[1].unsqueeze(0),pseudo_label[0].unsqueeze(0)))
                L_u = consistency_weight * unlabeled_weight * unlabeled_loss(
                    logits_u_s, pseudo_label)
            elif consistency_loss == 'CE':
                L_u = consistency_weight * unlabeled_loss(
                    logits_u_s, targets_u, pixelWiseWeight)

            loss = L_l + L_u

        else:
            loss = L_l

        if len(gpus) > 1:
            #print('before mean = ',loss)
            loss = loss.mean()
            #print('after mean = ',loss)
            loss_l_value += L_l.mean().item()
            if train_unlabeled:
                loss_u_value += L_u.mean().item()
        else:
            loss_l_value += L_l.item()
            if train_unlabeled:
                loss_u_value += L_u.item()

        loss.backward()
        optimizer.step()

        # update Mean teacher network
        if ema_model is not None:
            alpha_teacher = 0.99
            ema_model = update_ema_variables(ema_model=ema_model,
                                             model=model,
                                             alpha_teacher=alpha_teacher,
                                             iteration=i_iter)

        print(
            'iter = {0:6d}/{1:6d}, loss_l = {2:.3f}, loss_u = {3:.3f}'.format(
                i_iter, num_iterations, loss_l_value, loss_u_value))

        if i_iter % save_checkpoint_every == 0 and i_iter != 0:
            if epochs_since_start * len(trainloader) < save_checkpoint_every:
                _save_checkpoint(i_iter,
                                 model,
                                 optimizer,
                                 config,
                                 ema_model,
                                 overwrite=False)
            else:
                _save_checkpoint(i_iter, model, optimizer, config, ema_model)

        if config['utils']['tensorboard']:
            if 'tensorboard_writer' not in locals():
                tensorboard_writer = tensorboard.SummaryWriter(log_dir,
                                                               flush_secs=30)

            accumulated_loss_l.append(loss_l_value)
            if train_unlabeled:
                accumulated_loss_u.append(loss_u_value)
            if i_iter % log_per_iter == 0 and i_iter != 0:

                tensorboard_writer.add_scalar('Training/Supervised loss',
                                              np.mean(accumulated_loss_l),
                                              i_iter)
                accumulated_loss_l = []

                if train_unlabeled:
                    tensorboard_writer.add_scalar('Training/Unsupervised loss',
                                                  np.mean(accumulated_loss_u),
                                                  i_iter)
                    accumulated_loss_u = []

        if i_iter % val_per_iter == 0 and i_iter != 0:
            model.eval()
            if dataset == 'cityscapes':
                mIoU, eval_loss = evaluate(model,
                                           dataset,
                                           ignore_label=250,
                                           input_size=(512, 1024),
                                           save_dir=checkpoint_dir)
            elif dataset == 'multiview':
                mIoU, eval_loss = evaluate(model,
                                           dataset,
                                           ignore_label=255,
                                           input_size=(300, 300),
                                           save_dir=checkpoint_dir)
            else:
                print('error: unsupported dataset: {}'.format(dataset))
            model.train()

            if mIoU > best_mIoU and save_best_model:
                best_mIoU = mIoU
                _save_checkpoint(i_iter,
                                 model,
                                 optimizer,
                                 config,
                                 ema_model,
                                 save_best=True)

            if config['utils']['tensorboard']:
                tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
                tensorboard_writer.add_scalar('Validation/Loss', eval_loss,
                                              i_iter)
                print('iter {}, mIoU: {}'.format(i_iter, mIoU))

        if save_unlabeled_images and train_unlabeled and i_iter % save_checkpoint_every == 0:
            # Saves two mixed images and the corresponding prediction
            save_image(inputs_u_s[0].cpu(), i_iter, 'input1',
                       palette.CityScpates_palette)
            save_image(inputs_u_s[1].cpu(), i_iter, 'input2',
                       palette.CityScpates_palette)
            _, pred_u_s = torch.max(logits_u_s, dim=1)
            save_image(pred_u_s[0].cpu(), i_iter, 'pred1',
                       palette.CityScpates_palette)
            save_image(pred_u_s[1].cpu(), i_iter, 'pred2',
                       palette.CityScpates_palette)

    _save_checkpoint(num_iterations, model, optimizer, config, ema_model)

    model.eval()
    if dataset == 'cityscapes':
        mIoU, val_loss = evaluate(model,
                                  dataset,
                                  ignore_label=250,
                                  input_size=(512, 1024),
                                  save_dir=checkpoint_dir)
    elif dataset == 'multiview':
        mIoU, val_loss = evaluate(model,
                                  dataset,
                                  ignore_label=255,
                                  input_size=(300, 300),
                                  save_dir=checkpoint_dir)
    else:
        print('error: unsupported dataset: {}'.format(dataset))
    model.train()
    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter,
                         model,
                         optimizer,
                         config,
                         ema_model,
                         save_best=True)

    if config['utils']['tensorboard']:
        tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
        tensorboard_writer.add_scalar('Validation/Loss', val_loss, i_iter)

    end = timeit.default_timer()
    print('Total time: ' + str(end - start) + ' seconds')
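
Both training loops build their ClassMix masks with transformmasks.generate_class_mask. A minimal sketch of the semantics implied by the call sites above (an assumption, not the repository's exact code): a label map and a 1-D tensor of selected class ids go in, a binary HxW mask comes out.

# Assumed semantics of generate_class_mask: pixels whose label is in
# `classes` become 1, everything else 0.
def generate_class_mask_sketch(label, classes):
    # label: HxW long tensor; classes: 1-D tensor of selected class ids
    mask = label.unsqueeze(0).eq(classes.view(-1, 1, 1)).any(dim=0)
    return mask.float()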