def augment_samples(images, labels, probs, do_classmix, batch_size, ignore_label, weak=False):
    """
    Perform data augmentation

    Args:
        images: BxCxWxH images to augment
        labels: BxWxH labels to augment
        probs: BxWxH probability maps to augment
        do_classmix: whether to apply classmix augmentation
        batch_size: batch size
        ignore_label: ignore class value
        weak: whether to perform weak or strong augmentation

    Returns:
        augmented data, augmented labels, augmented probs
    """
    if do_classmix:
        # ClassMix: get mask for image A
        for image_i in range(batch_size):  # for each image
            classes = torch.unique(labels[image_i])  # get unique classes in pseudolabel A
            nclasses = classes.shape[0]

            # remove ignore class
            if ignore_label in classes and len(classes) > 1 and nclasses > 1:
                classes = classes[classes != ignore_label]
                nclasses = nclasses - 1

            if dataset == 'pascal_voc':  # if voc dataset, remove class 0 (background)
                if 0 in classes and len(classes) > 1 and nclasses > 1:
                    classes = classes[classes != 0]
                    nclasses = nclasses - 1

            # pick half of the classes randomly
            classes = (classes[torch.Tensor(
                np.random.choice(nclasses,
                                 int(((nclasses - nclasses % 2) / 2) + 1),
                                 replace=False)).long()]).cuda()

            # accumulate masks
            if image_i == 0:
                MixMask = transformmasks.generate_class_mask(
                    labels[image_i], classes).unsqueeze(0).cuda()
            else:
                MixMask = torch.cat(
                    (MixMask, transformmasks.generate_class_mask(
                        labels[image_i], classes).unsqueeze(0).cuda()))

        params = {"Mix": MixMask}
    else:
        params = {}

    if weak:
        params["flip"] = random.random() < 0.5
        params["ColorJitter"] = random.random() < 0.2
        params["GaussianBlur"] = random.random() < 0.
        params["Grayscale"] = random.random() < 0.0
        params["Solarize"] = random.random() < 0.0
        if random.random() < 0.5:
            scale = random.uniform(0.75, 1.75)
        else:
            scale = 1
        params["RandomScaleCrop"] = scale

        # Apply weak augmentations to unlabeled images
        image_aug, labels_aug, probs_aug = augmentationTransform(
            params, data=images, target=labels, probs=probs,
            jitter_vale=0.125, min_sigma=0.1, max_sigma=1.5,
            ignore_label=ignore_label)
    else:
        params["flip"] = random.random() < 0.5
        params["ColorJitter"] = random.random() < 0.8
        params["GaussianBlur"] = random.random() < 0.2
        params["Grayscale"] = random.random() < 0.0
        params["Solarize"] = random.random() < 0.0
        if random.random() < 0.80:
            scale = random.uniform(0.75, 1.75)
        else:
            scale = 1
        params["RandomScaleCrop"] = scale

        # Apply strong augmentations to unlabeled images
        image_aug, labels_aug, probs_aug = augmentationTransform(
            params, data=images, target=labels, probs=probs,
            jitter_vale=0.25, min_sigma=0.1, max_sigma=1.5,
            ignore_label=ignore_label)

    return image_aug, labels_aug, probs_aug, params
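# Illustrative usage sketch only (an assumption, not part of the original training code):
# how augment_samples is typically driven from a mean-teacher step. The helper name
# _augment_samples_example is hypothetical; `ema_model`, `interp`, `batch_size` and
# `ignore_label` are assumed to be the objects defined in main() below.
def _augment_samples_example(ema_model, interp, unlabeled_images, batch_size, ignore_label):
    with torch.no_grad():
        logits_u_w = interp(ema_model(unlabeled_images))         # teacher forward pass
        probs_u_w = torch.softmax(logits_u_w, dim=1)
        max_probs, pseudo_labels = torch.max(probs_u_w, dim=1)   # BxWxH confidences / pseudo-labels
    # ClassMix + strong photometric augmentation of the pseudo-labeled batch
    return augment_samples(unlabeled_images, pseudo_labels, max_probs,
                           do_classmix=True, batch_size=batch_size,
                           ignore_label=ignore_label, weak=False)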
def main():
    torch.cuda.empty_cache()
    print(config)

    best_mIoU = 0

    if consistency_loss == 'CE':
        if len(gpus) > 1:
            unlabeled_loss = torch.nn.DataParallel(
                CrossEntropyLoss2dPixelWiseWeighted(ignore_index=ignore_label),
                device_ids=gpus).cuda()
        else:
            unlabeled_loss = CrossEntropyLoss2dPixelWiseWeighted().cuda()
    elif consistency_loss == 'MSE':
        if len(gpus) > 1:
            unlabeled_loss = torch.nn.DataParallel(MSELoss2d(), device_ids=gpus).cuda()
        else:
            unlabeled_loss = MSELoss2d().cuda()

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=num_classes)

    # load pretrained parameters
    if restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(restore_from)
    else:
        saved_state_dict = torch.load(restore_from)

    # Copy loaded parameters to model
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # Initiate ema-model
    if train_unlabeled:
        ema_model = create_ema_model(model)
        ema_model.train()
        ema_model = ema_model.cuda()
    else:
        ema_model = None

    if len(gpus) > 1:
        if use_sync_batchnorm:
            model = convert_model(model)
            model = DataParallelWithCallback(model, device_ids=gpus)
        else:
            model = torch.nn.DataParallel(model, device_ids=gpus)
    model.train()
    model.cuda()

    cudnn.benchmark = True

    if dataset == 'pascal_voc':
        data_loader = get_loader(dataset)
        data_path = get_data_path(dataset)
        train_dataset = data_loader(data_path, crop_size=input_size,
                                    scale=random_scale, mirror=random_flip)
    elif dataset == 'cityscapes':
        data_loader = get_loader('cityscapes')
        data_path = get_data_path('cityscapes')
        if random_crop:
            data_aug = Compose([RandomCrop_city(input_size)])
        else:
            data_aug = None
        train_dataset = data_loader(data_path, is_transform=True,
                                    augmentations=data_aug, img_size=input_size)

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    partial_size = labeled_samples
    print('Training on number of samples:', partial_size)

    if split_id is not None:
        train_ids = pickle.load(open(split_id, 'rb'))
        print('loading train ids from {}'.format(split_id))
    else:
        np.random.seed(random_seed)
        train_ids = np.arange(train_dataset_size)
        np.random.shuffle(train_ids)

    train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
    trainloader = data.DataLoader(train_dataset, batch_size=batch_size,
                                  sampler=train_sampler, num_workers=num_workers,
                                  pin_memory=True)
    trainloader_iter = iter(trainloader)

    if train_unlabeled:
        train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:])
        trainloader_remain = data.DataLoader(train_dataset, batch_size=batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=1, pin_memory=True)
        trainloader_remain_iter = iter(trainloader_remain)

    # Optimizer for segmentation network
    learning_rate_object = Learning_Rate_Object(config['training']['learning_rate'])

    if optimizer_type == 'SGD':
        if len(gpus) > 1:
            optimizer = optim.SGD(model.module.optim_parameters(learning_rate_object),
                                  lr=learning_rate, momentum=momentum,
                                  weight_decay=weight_decay)
        else:
            optimizer = optim.SGD(model.optim_parameters(learning_rate_object),  ## DOES THIS CAUSE THE USERWARNING?
                                  lr=learning_rate, momentum=momentum,
                                  weight_decay=weight_decay)

    optimizer.zero_grad()

    interp = nn.Upsample(size=(input_size[0], input_size[1]),
                         mode='bilinear', align_corners=True)

    start_iteration = 0
    if args.resume:
        start_iteration, model, optimizer, ema_model = _resume_checkpoint(
            args.resume, model, optimizer, ema_model)

    accumulated_loss_l = []
    if train_unlabeled:
        accumulated_loss_u = []

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    with open(checkpoint_dir + '/config.json', 'w') as handle:
        json.dump(config, handle, indent=4, sort_keys=False)
    pickle.dump(train_ids, open(os.path.join(checkpoint_dir, 'train_split.pkl'), 'wb'))

    epochs_since_start = 0
    for i_iter in range(start_iteration, num_iterations):
        model.train()

        loss_l_value = 0
        if train_unlabeled:
            loss_u_value = 0

        optimizer.zero_grad()

        if lr_schedule:
            adjust_learning_rate(optimizer, i_iter)

        # Training loss for labeled data only
        try:
            batch = next(trainloader_iter)
            if batch[0].shape[0] != batch_size:
                batch = next(trainloader_iter)
        except:
            epochs_since_start = epochs_since_start + 1
            print('Epochs since start: ', epochs_since_start)
            trainloader_iter = iter(trainloader)
            batch = next(trainloader_iter)

        weak_parameters = {"flip": 0}

        images, labels, _, _, _ = batch
        images = images.cuda()
        labels = labels.cuda()
        images, labels = weakTransform(weak_parameters, data=images, target=labels)

        intermediary_var = model(images)
        pred = interp(intermediary_var)
        L_l = loss_calc(pred, labels)

        if train_unlabeled:
            try:
                batch_remain = next(trainloader_remain_iter)
                if batch_remain[0].shape[0] != batch_size:
                    batch_remain = next(trainloader_remain_iter)
            except:
                trainloader_remain_iter = iter(trainloader_remain)
                batch_remain = next(trainloader_remain_iter)

            images_remain, _, _, _, _ = batch_remain
            images_remain = images_remain.cuda()

            inputs_u_w, _ = weakTransform(weak_parameters, data=images_remain)
            logits_u_w = interp(ema_model(inputs_u_w))
            logits_u_w, _ = weakTransform(
                getWeakInverseTransformParameters(weak_parameters),
                data=logits_u_w.detach())

            softmax_u_w = torch.softmax(logits_u_w.detach(), dim=1)
            max_probs, argmax_u_w = torch.max(softmax_u_w, dim=1)

            if mix_mask == "class":
                for image_i in range(batch_size):
                    classes = torch.unique(argmax_u_w[image_i])
                    classes = classes[classes != ignore_label]
                    nclasses = classes.shape[0]
                    classes = (classes[torch.Tensor(
                        np.random.choice(nclasses,
                                         int((nclasses - nclasses % 2) / 2),
                                         replace=False)).long()]).cuda()
                    if image_i == 0:
                        MixMask = transformmasks.generate_class_mask(
                            argmax_u_w[image_i], classes).unsqueeze(0).cuda()
                    else:
                        MixMask = torch.cat(
                            (MixMask, transformmasks.generate_class_mask(
                                argmax_u_w[image_i], classes).unsqueeze(0).cuda()))

            elif mix_mask == 'cut':
                img_size = inputs_u_w.shape[2:4]
                for image_i in range(batch_size):
                    if image_i == 0:
                        MixMask = torch.from_numpy(
                            transformmasks.generate_cutout_mask(
                                img_size)).unsqueeze(0).cuda().float()
                    else:
                        MixMask = torch.cat(
                            (MixMask, torch.from_numpy(
                                transformmasks.generate_cutout_mask(
                                    img_size)).unsqueeze(0).cuda().float()))

            elif mix_mask == "cow":
                img_size = inputs_u_w.shape[2:4]
                sigma_min = 8
                sigma_max = 32
                p_min = 0.5
                p_max = 0.5
                for image_i in range(batch_size):
                    sigma = np.exp(np.random.uniform(np.log(sigma_min),
                                                     np.log(sigma_max)))  # Random sigma
                    p = np.random.uniform(p_min, p_max)  # Random p
                    if image_i == 0:
                        MixMask = torch.from_numpy(
                            transformmasks.generate_cow_mask(
                                img_size, sigma, p, seed=None)).unsqueeze(0).cuda().float()
                    else:
                        MixMask = torch.cat(
                            (MixMask, torch.from_numpy(
                                transformmasks.generate_cow_mask(
                                    img_size, sigma, p,
                                    seed=None)).unsqueeze(0).cuda().float()))

            elif mix_mask == None:
                MixMask = torch.ones((inputs_u_w.shape)).cuda()

            strong_parameters = {"Mix": MixMask}
            if random_flip:
                strong_parameters["flip"] = random.randint(0, 1)
            else:
                strong_parameters["flip"] = 0
            if color_jitter:
                strong_parameters["ColorJitter"] = random.uniform(0, 1)
            else:
                strong_parameters["ColorJitter"] = 0
            if gaussian_blur:
                strong_parameters["GaussianBlur"] = random.uniform(0, 1)
            else:
                strong_parameters["GaussianBlur"] = 0

            inputs_u_s, _ = strongTransform(strong_parameters, data=images_remain)
            logits_u_s = interp(model(inputs_u_s))

            softmax_u_w_mixed, _ = strongTransform(strong_parameters, data=softmax_u_w)
            max_probs, pseudo_label = torch.max(softmax_u_w_mixed, dim=1)

            if pixel_weight == "threshold_uniform":
                unlabeled_weight = torch.sum(
                    max_probs.ge(0.968).long() == 1).item() / np.size(
                        np.array(pseudo_label.cpu()))
                pixelWiseWeight = unlabeled_weight * torch.ones(max_probs.shape).cuda()
            elif pixel_weight == "threshold":
                pixelWiseWeight = max_probs.ge(0.968).long().cuda()
            elif pixel_weight == 'sigmoid':
                max_iter = 10000
                pixelWiseWeight = sigmoid_ramp_up(i_iter, max_iter) * torch.ones(
                    max_probs.shape).cuda()
            elif pixel_weight == False:
                pixelWiseWeight = torch.ones(max_probs.shape).cuda()

            if consistency_loss == 'CE':
                L_u = consistency_weight * unlabeled_loss(
                    logits_u_s, pseudo_label, pixelWiseWeight)
            elif consistency_loss == 'MSE':
                unlabeled_weight = torch.sum(
                    max_probs.ge(0.968).long() == 1).item() / np.size(
                        np.array(pseudo_label.cpu()))
                #softmax_u_w_mixed = torch.cat((softmax_u_w_mixed[1].unsqueeze(0),softmax_u_w_mixed[0].unsqueeze(0)))
                L_u = consistency_weight * unlabeled_weight * unlabeled_loss(
                    logits_u_s, softmax_u_w_mixed)

            loss = L_l + L_u
        else:
            loss = L_l

        if len(gpus) > 1:
            loss = loss.mean()
            loss_l_value += L_l.mean().item()
            if train_unlabeled:
                loss_u_value += L_u.mean().item()
        else:
            loss_l_value += L_l.item()
            if train_unlabeled:
                loss_u_value += L_u.item()

        loss.backward()
        optimizer.step()

        # update Mean teacher network
        if ema_model is not None:
            alpha_teacher = 0.99
            ema_model = update_ema_variables(ema_model=ema_model, model=model,
                                             alpha_teacher=alpha_teacher,
                                             iteration=i_iter)

        if train_unlabeled:
            print('iter = {0:6d}/{1:6d}, loss_l = {2:.3f}, loss_u = {3:.3f}'.format(
                i_iter, num_iterations, loss_l_value, loss_u_value))
        else:
            print('iter = {0:6d}/{1:6d}, loss_l = {2:.3f}'.format(
                i_iter, num_iterations, loss_l_value))

        if i_iter % save_checkpoint_every == 0 and i_iter != 0:
            _save_checkpoint(i_iter, model, optimizer, config, ema_model)

        if use_tensorboard:
            if 'tensorboard_writer' not in locals():
                tensorboard_writer = tensorboard.SummaryWriter(log_dir, flush_secs=30)

            accumulated_loss_l.append(loss_l_value)
            if train_unlabeled:
                accumulated_loss_u.append(loss_u_value)
            if i_iter % log_per_iter == 0 and i_iter != 0:
                tensorboard_writer.add_scalar('Training/Supervised loss',
                                              np.mean(accumulated_loss_l), i_iter)
                accumulated_loss_l = []
                if train_unlabeled:
                    tensorboard_writer.add_scalar('Training/Unsupervised loss',
                                                  np.mean(accumulated_loss_u), i_iter)
                    accumulated_loss_u = []

        if i_iter % val_per_iter == 0 and i_iter != 0:
            model.eval()
            mIoU, eval_loss = evaluate(model, dataset, ignore_label=ignore_label,
                                       input_size=(512, 1024), save_dir=checkpoint_dir)
            model.train()

            if mIoU > best_mIoU and save_best_model:
                best_mIoU = mIoU
                _save_checkpoint(i_iter, model, optimizer, config, ema_model,
                                 save_best=True)

            if use_tensorboard:
                tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
                tensorboard_writer.add_scalar('Validation/Loss', eval_loss, i_iter)

        if save_unlabeled_images and train_unlabeled and i_iter % save_checkpoint_every == 0:
            # Saves two mixed images and the corresponding prediction
            save_image(inputs_u_s[0].cpu(), i_iter, 'input1', palette.CityScpates_palette)
            save_image(inputs_u_s[1].cpu(), i_iter, 'input2', palette.CityScpates_palette)
            _, pred_u_s = torch.max(logits_u_s, dim=1)
            save_image(pred_u_s[0].cpu(), i_iter, 'pred1', palette.CityScpates_palette)
            save_image(pred_u_s[1].cpu(), i_iter, 'pred2', palette.CityScpates_palette)

    _save_checkpoint(num_iterations, model, optimizer, config, ema_model)

    model.eval()
    mIoU, val_loss = evaluate(model, dataset, ignore_label=ignore_label,
                              input_size=(512, 1024), save_dir=checkpoint_dir)
    model.train()

    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter, model, optimizer, config, ema_model, save_best=True)

    if use_tensorboard:
        tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
        tensorboard_writer.add_scalar('Validation/Loss', val_loss, i_iter)

    end = timeit.default_timer()
    print('Total time: ' + str(end - start) + ' seconds')
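# Minimal sketch (assumption; not the repo's update_ema_variables) of the mean-teacher
# EMA update both training loops rely on: every teacher parameter is an exponential
# moving average of the student, theta_ema <- alpha * theta_ema + (1 - alpha) * theta.
def ema_update_sketch(ema_model, model, alpha_teacher=0.99, iteration=0):
    # Ramp alpha up from 0 so the teacher simply copies the student early in training
    # (the common min(1 - 1/(it + 1), alpha) schedule used by mean-teacher code).
    alpha = min(1.0 - 1.0 / (iteration + 1), alpha_teacher)
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.data.mul_(alpha).add_(p.data, alpha=1.0 - alpha)
    return ema_model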
def main():
    print(config)

    best_mIoU = 0

    if consistency_loss == 'MSE':
        if len(gpus) > 1:
            unlabeled_loss = torch.nn.DataParallel(MSELoss2d(), device_ids=gpus).cuda()
        else:
            unlabeled_loss = MSELoss2d().cuda()
    elif consistency_loss == 'CE':
        if len(gpus) > 1:
            unlabeled_loss = torch.nn.DataParallel(
                CrossEntropyLoss2dPixelWiseWeighted(ignore_index=ignore_label),
                device_ids=gpus).cuda()
        else:
            unlabeled_loss = CrossEntropyLoss2dPixelWiseWeighted(
                ignore_index=ignore_label).cuda()

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=num_classes)

    # load pretrained parameters
    #saved_state_dict = torch.load(args.restore_from)
    if restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(restore_from)
    else:
        saved_state_dict = torch.load(restore_from)

    # Copy loaded parameters to model
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # init ema-model
    if train_unlabeled:
        ema_model = create_ema_model(model)
        ema_model.train()
        ema_model = ema_model.cuda()
    else:
        ema_model = None

    if len(gpus) > 1:
        if use_sync_batchnorm:
            model = convert_model(model)
            model = DataParallelWithCallback(model, device_ids=gpus)
        else:
            model = torch.nn.DataParallel(model, device_ids=gpus)
    model.train()
    model.cuda()

    cudnn.benchmark = True

    data_loader = get_loader(config['dataset'])
    # data_path = get_data_path(config['dataset'])
    # if random_crop:
    #     data_aug = Compose([RandomCrop_city(input_size)])
    # else:
    #     data_aug = None
    data_aug = Compose([RandomHorizontallyFlip()])

    if dataset == 'cityscapes':
        train_dataset = data_loader(data_path, is_transform=True,
                                    augmentations=data_aug, img_size=input_size,
                                    img_mean=IMG_MEAN)
    elif dataset == 'multiview':
        # adaptation data
        data_path = '/tmp/tcn_data/texture_multibot_push_left10050/videos/train_adaptation'
        train_dataset = data_loader(data_path, is_transform=True,
                                    view_idx=0, number_views=1, load_seg_mask=False,
                                    augmentations=data_aug, img_size=input_size,
                                    img_mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    if labeled_samples is None:
        trainloader = data.DataLoader(train_dataset, batch_size=batch_size,
                                      shuffle=True, num_workers=num_workers,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset, batch_size=batch_size,
                                             shuffle=True, num_workers=num_workers,
                                             pin_memory=True)
        trainloader_remain_iter = iter(trainloader_remain)
    else:
        partial_size = labeled_samples
        print('Training on number of samples:', partial_size)
        np.random.seed(random_seed)
        trainloader_remain = data.DataLoader(train_dataset, batch_size=batch_size,
                                             shuffle=True, num_workers=4,
                                             pin_memory=True)
        trainloader_remain_iter = iter(trainloader_remain)

        # New loader for domain transfer
        # if random_crop:
        #     data_aug = Compose([RandomCrop_gta(input_size)])
        # else:
        #     data_aug = None

        # SUPERVISED DATA
        data_path = '/tmp/tcn_data/texture_multibot_push_left10050/videos/train_adaptation'
        data_aug = Compose([RandomHorizontallyFlip()])
        if dataset == 'multiview':
            train_dataset = data_loader(data_path, is_transform=True,
                                        view_idx=0, number_views=1, load_seg_mask=True,
                                        augmentations=data_aug, img_size=input_size,
                                        img_mean=IMG_MEAN)
        else:
            data_loader = get_loader('gta')
            data_path = get_data_path('gta')
            train_dataset = data_loader(data_path,
                                        list_path='./data/gta5_list/train.txt',
                                        augmentations=data_aug,
                                        img_size=(1280, 720), mean=IMG_MEAN)

    # training loss for labeled data only
    trainloader = data.DataLoader(train_dataset, batch_size=batch_size,
                                  shuffle=True, num_workers=num_workers,
                                  pin_memory=True)
    trainloader_iter = iter(trainloader)
    print('gta size:', len(trainloader))

    # Load new data for domain_transfer

    # optimizer for segmentation network
    learning_rate_object = Learning_Rate_Object(config['training']['learning_rate'])

    if optimizer_type == 'SGD':
        if len(gpus) > 1:
            optimizer = optim.SGD(model.module.optim_parameters(learning_rate_object),
                                  lr=learning_rate, momentum=momentum,
                                  weight_decay=weight_decay)
        else:
            optimizer = optim.SGD(model.optim_parameters(learning_rate_object),
                                  lr=learning_rate, momentum=momentum,
                                  weight_decay=weight_decay)
    elif optimizer_type == 'Adam':
        # Adam takes no momentum argument, so only lr and weight_decay are passed
        if len(gpus) > 1:
            optimizer = optim.Adam(model.module.optim_parameters(learning_rate_object),
                                   lr=learning_rate, weight_decay=weight_decay)
        else:
            optimizer = optim.Adam(model.optim_parameters(learning_rate_object),
                                   lr=learning_rate, weight_decay=weight_decay)

    optimizer.zero_grad()

    interp = nn.Upsample(size=(input_size[0], input_size[1]),
                         mode='bilinear', align_corners=True)

    start_iteration = 0
    if args.resume:
        start_iteration, model, optimizer, ema_model = _resume_checkpoint(
            args.resume, model, optimizer, ema_model)

    accumulated_loss_l = []
    accumulated_loss_u = []

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    with open(checkpoint_dir + '/config.json', 'w') as handle:
        json.dump(config, handle, indent=4, sort_keys=True)

    epochs_since_start = 0
    for i_iter in range(start_iteration, num_iterations):
        model.train()

        loss_u_value = 0
        loss_l_value = 0

        optimizer.zero_grad()

        if lr_schedule:
            adjust_learning_rate(optimizer, i_iter)

        # training loss for labeled data only
        try:
            batch = next(trainloader_iter)
            if batch[0].shape[0] != batch_size:
                batch = next(trainloader_iter)
        except:
            epochs_since_start = epochs_since_start + 1
            print('Epochs since start: ', epochs_since_start)
            trainloader_iter = iter(trainloader)
            batch = next(trainloader_iter)

        #if random_flip:
        #    weak_parameters={"flip":random.randint(0,1)}
        #else:
        weak_parameters = {"flip": 0}

        images, labels, _, _ = batch
        images = images.cuda()
        labels = labels.cuda().long()
        #images, labels = weakTransform(weak_parameters, data = images, target = labels)

        pred = interp(model(images))
        L_l = loss_calc(pred, labels)  # Cross entropy loss for labeled data
        #L_l = torch.Tensor([0.0]).cuda()

        if train_unlabeled:
            try:
                batch_remain = next(trainloader_remain_iter)
                if batch_remain[0].shape[0] != batch_size:
                    batch_remain = next(trainloader_remain_iter)
            except:
                trainloader_remain_iter = iter(trainloader_remain)
                batch_remain = next(trainloader_remain_iter)

            images_remain, *_ = batch_remain
            images_remain = images_remain.cuda()

            inputs_u_w, _ = weakTransform(weak_parameters, data=images_remain)
            #inputs_u_w = inputs_u_w.clone()
            logits_u_w = interp(ema_model(inputs_u_w))
            logits_u_w, _ = weakTransform(
                getWeakInverseTransformParameters(weak_parameters),
                data=logits_u_w.detach())

            pseudo_label = torch.softmax(logits_u_w.detach(), dim=1)
            max_probs, targets_u_w = torch.max(pseudo_label, dim=1)

            if mix_mask == "class":
                # Note: this loop assumes batch_size == 2; it builds one ClassMix mask
                # per image (MixMask0 for the first image, MixMask1 for the second).
                for image_i in range(batch_size):
                    classes = torch.unique(labels[image_i])
                    #classes=classes[classes!=ignore_label]
                    nclasses = classes.shape[0]
                    #if nclasses > 0:
                    classes = (classes[torch.Tensor(
                        np.random.choice(nclasses,
                                         int((nclasses + nclasses % 2) / 2),
                                         replace=False)).long()]).cuda()
                    if image_i == 0:
                        MixMask0 = transformmasks.generate_class_mask(
                            labels[image_i], classes).unsqueeze(0).cuda()
                    else:
                        MixMask1 = transformmasks.generate_class_mask(
                            labels[image_i], classes).unsqueeze(0).cuda()
            elif mix_mask == None:
                MixMask = torch.ones((inputs_u_w.shape))

            strong_parameters = {"Mix": MixMask0}
            if random_flip:
                strong_parameters["flip"] = random.randint(0, 1)
            else:
                strong_parameters["flip"] = 0
            if color_jitter:
                strong_parameters["ColorJitter"] = random.uniform(0, 1)
            else:
                strong_parameters["ColorJitter"] = 0
            if gaussian_blur:
                strong_parameters["GaussianBlur"] = random.uniform(0, 1)
            else:
                strong_parameters["GaussianBlur"] = 0

            inputs_u_s0, _ = strongTransform(
                strong_parameters,
                data=torch.cat((images[0].unsqueeze(0), images_remain[0].unsqueeze(0))))
            strong_parameters["Mix"] = MixMask1
            inputs_u_s1, _ = strongTransform(
                strong_parameters,
                data=torch.cat((images[1].unsqueeze(0), images_remain[1].unsqueeze(0))))
            inputs_u_s = torch.cat((inputs_u_s0, inputs_u_s1))
            logits_u_s = interp(model(inputs_u_s))

            strong_parameters["Mix"] = MixMask0
            _, targets_u0 = strongTransform(
                strong_parameters,
                target=torch.cat((labels[0].unsqueeze(0), targets_u_w[0].unsqueeze(0))))
            strong_parameters["Mix"] = MixMask1
            _, targets_u1 = strongTransform(
                strong_parameters,
                target=torch.cat((labels[1].unsqueeze(0), targets_u_w[1].unsqueeze(0))))
            targets_u = torch.cat((targets_u0, targets_u1)).long()

            if pixel_weight == "threshold_uniform":
                unlabeled_weight = torch.sum(
                    max_probs.ge(0.968).long() == 1).item() / np.size(
                        np.array(targets_u.cpu()))
                pixelWiseWeight = unlabeled_weight * torch.ones(max_probs.shape).cuda()
            elif pixel_weight == "threshold":
                pixelWiseWeight = max_probs.ge(0.968).float().cuda()
            elif pixel_weight == False:
                pixelWiseWeight = torch.ones(max_probs.shape).cuda()

            onesWeights = torch.ones((pixelWiseWeight.shape)).cuda()
            strong_parameters["Mix"] = MixMask0
            _, pixelWiseWeight0 = strongTransform(
                strong_parameters,
                target=torch.cat((onesWeights[0].unsqueeze(0),
                                  pixelWiseWeight[0].unsqueeze(0))))
            strong_parameters["Mix"] = MixMask1
            _, pixelWiseWeight1 = strongTransform(
                strong_parameters,
                target=torch.cat((onesWeights[1].unsqueeze(0),
                                  pixelWiseWeight[1].unsqueeze(0))))
            pixelWiseWeight = torch.cat((pixelWiseWeight0, pixelWiseWeight1)).cuda()

            if consistency_loss == 'MSE':
                unlabeled_weight = torch.sum(
                    max_probs.ge(0.968).long() == 1).item() / np.size(
                        np.array(targets_u.cpu()))
                #pseudo_label = torch.cat((pseudo_label[1].unsqueeze(0),pseudo_label[0].unsqueeze(0)))
                L_u = consistency_weight * unlabeled_weight * unlabeled_loss(
                    logits_u_s, pseudo_label)
            elif consistency_loss == 'CE':
                L_u = consistency_weight * unlabeled_loss(
                    logits_u_s, targets_u, pixelWiseWeight)

            loss = L_l + L_u
        else:
            loss = L_l

        if len(gpus) > 1:
            #print('before mean = ',loss)
            loss = loss.mean()
            #print('after mean = ',loss)
            loss_l_value += L_l.mean().item()
            if train_unlabeled:
                loss_u_value += L_u.mean().item()
        else:
            loss_l_value += L_l.item()
            if train_unlabeled:
                loss_u_value += L_u.item()

        loss.backward()
        optimizer.step()

        # update Mean teacher network
        if ema_model is not None:
            alpha_teacher = 0.99
            ema_model = update_ema_variables(ema_model=ema_model, model=model,
                                             alpha_teacher=alpha_teacher,
                                             iteration=i_iter)

        print('iter = {0:6d}/{1:6d}, loss_l = {2:.3f}, loss_u = {3:.3f}'.format(
            i_iter, num_iterations, loss_l_value, loss_u_value))

        if i_iter % save_checkpoint_every == 0 and i_iter != 0:
            if epochs_since_start * len(trainloader) < save_checkpoint_every:
                _save_checkpoint(i_iter, model, optimizer, config, ema_model,
                                 overwrite=False)
            else:
                _save_checkpoint(i_iter, model, optimizer, config, ema_model)

        if config['utils']['tensorboard']:
            if 'tensorboard_writer' not in locals():
                tensorboard_writer = tensorboard.SummaryWriter(log_dir, flush_secs=30)

            accumulated_loss_l.append(loss_l_value)
            if train_unlabeled:
                accumulated_loss_u.append(loss_u_value)
            if i_iter % log_per_iter == 0 and i_iter != 0:
                tensorboard_writer.add_scalar('Training/Supervised loss',
                                              np.mean(accumulated_loss_l), i_iter)
                accumulated_loss_l = []
                if train_unlabeled:
                    tensorboard_writer.add_scalar('Training/Unsupervised loss',
                                                  np.mean(accumulated_loss_u), i_iter)
                    accumulated_loss_u = []

        if i_iter % val_per_iter == 0 and i_iter != 0:
            model.eval()
            if dataset == 'cityscapes':
                mIoU, eval_loss = evaluate(model, dataset, ignore_label=250,
                                           input_size=(512, 1024),
                                           save_dir=checkpoint_dir)
            elif dataset == 'multiview':
                mIoU, eval_loss = evaluate(model, dataset, ignore_label=255,
                                           input_size=(300, 300),
                                           save_dir=checkpoint_dir)
            else:
                print('unknown dataset: {}'.format(dataset))
            model.train()

            if mIoU > best_mIoU and save_best_model:
                best_mIoU = mIoU
                _save_checkpoint(i_iter, model, optimizer, config, ema_model,
                                 save_best=True)

            if config['utils']['tensorboard']:
                tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
                tensorboard_writer.add_scalar('Validation/Loss', eval_loss, i_iter)
            print('iter {}, mIoU: {}'.format(i_iter, mIoU))

        if save_unlabeled_images and train_unlabeled and i_iter % save_checkpoint_every == 0:
            # Saves two mixed images and the corresponding prediction
            save_image(inputs_u_s[0].cpu(), i_iter, 'input1', palette.CityScpates_palette)
            save_image(inputs_u_s[1].cpu(), i_iter, 'input2', palette.CityScpates_palette)
            _, pred_u_s = torch.max(logits_u_s, dim=1)
            save_image(pred_u_s[0].cpu(), i_iter, 'pred1', palette.CityScpates_palette)
            save_image(pred_u_s[1].cpu(), i_iter, 'pred2', palette.CityScpates_palette)

    _save_checkpoint(num_iterations, model, optimizer, config, ema_model)

    model.eval()
    if dataset == 'cityscapes':
        mIoU, val_loss = evaluate(model, dataset, ignore_label=250,
                                  input_size=(512, 1024), save_dir=checkpoint_dir)
    elif dataset == 'multiview':
        mIoU, val_loss = evaluate(model, dataset, ignore_label=255,
                                  input_size=(300, 300), save_dir=checkpoint_dir)
    else:
        print('unknown dataset: {}'.format(dataset))
    model.train()

    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter, model, optimizer, config, ema_model, save_best=True)

    if config['utils']['tensorboard']:
        tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
        tensorboard_writer.add_scalar('Validation/Loss', val_loss, i_iter)

    end = timeit.default_timer()
    print('Total time: ' + str(end - start) + ' seconds')
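# Hedged sketch of a sigmoid ramp-up of the kind used by the pixel_weight == 'sigmoid'
# branch in the first main() above (Laine & Aila / mean-teacher style). This is an
# assumption about sigmoid_ramp_up's behaviour; the repo's own constants may differ.
def sigmoid_ramp_up_sketch(current_iter, ramp_up_length):
    # Weight ramps smoothly from ~0 at iteration 0 to 1.0 at ramp_up_length.
    if ramp_up_length == 0:
        return 1.0
    t = np.clip(current_iter / float(ramp_up_length), 0.0, 1.0)
    return float(np.exp(-5.0 * (1.0 - t) ** 2))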