def train_seg_semisup_vat_mt(
        submit_config: job_helper.SubmitConfig, dataset, model, arch, freeze_bn,
        opt_type, sgd_momentum, sgd_nesterov, sgd_weight_decay, learning_rate,
        lr_sched, lr_step_epochs, lr_step_gamma, lr_poly_power, teacher_alpha,
        bin_fill_holes, crop_size, aug_hflip, aug_vflip, aug_hvflip,
        aug_scale_hung, aug_max_scale, aug_scale_non_uniform, aug_rot_mag,
        vat_radius, adaptive_vat_radius, vat_dir_from_student, cons_loss_fn,
        cons_weight, conf_thresh, conf_per_pixel, rampup, unsup_batch_ratio,
        num_epochs, iters_per_epoch, batch_size, n_sup, n_unsup, n_val,
        split_seed, split_path, val_seed, save_preds, save_model, num_workers):
    """
    Train a semantic segmentation network with supervised cross-entropy plus a
    VAT (virtual adversarial training) consistency loss between a student and a
    teacher network.

    Each training iteration processes one supervised batch and, when
    `cons_weight > 0`, `unsup_batch_ratio` unsupervised batches. For each
    unsupervised batch an adversarial perturbation is computed, the student
    predicts on the perturbed image, the teacher predicts on the clean image and
    the masked difference between the two predictions is penalised.
    After every epoch the evaluation network is scored (per-class IoU and mIoU)
    on the validation set(s); after training the model and/or predictions are
    optionally saved and the test set is scored.

    :param submit_config: job configuration; only `run_dir` is read here (output
        directory for `model.pth` and the `preds` folder)
    :param dataset, n_val, val_seed, n_sup, n_unsup, split_seed, split_path:
        forwarded unchanged to `datasets.load_dataset`
    :param model: `'mean_teacher'` (separate teacher updated through
        `EMAWeightOptimizer`; the teacher is evaluated) or `'pi'` (teacher is the
        student itself)
    :param arch: key looked up in `network_architectures.seg`
    :param freeze_bn: if true, `freeze_batchnorm()` is called on the networks at
        the start of each epoch
    :param opt_type: `'adam'` or `'sgd'`
    :param sgd_momentum, sgd_nesterov, sgd_weight_decay: SGD-only options
    :param learning_rate: learning rate for new parameters; pretrained
        parameters use `0.1 * learning_rate`
    :param lr_sched, lr_step_epochs, lr_step_gamma, lr_poly_power: forwarded to
        `lr_schedules.make_lr_schedulers`
    :param teacher_alpha: forwarded to `EMAWeightOptimizer`
    :param bin_fill_holes: forwarded to `EvaluatorIoU`; only permitted for
        2-class datasets
    :param crop_size: comma separated integers as a string, or `''` for no crop
    :param aug_hflip, aug_vflip, aug_hvflip, aug_scale_hung, aug_max_scale,
        aug_scale_non_uniform, aug_rot_mag: augmentation settings used to build
        the training transform list
    :param vat_radius: scale of the adversarial perturbation
    :param adaptive_vat_radius: if true, scale the radius per image by the
        magnitude of the image's finite-difference gradients
    :param vat_dir_from_student: if true, the student (rather than the teacher)
        is used to compute the adversarial direction
    :param cons_loss_fn: `'var'`, `'logits_var'`, `'bce'` or `'kld'`
    :param cons_weight: weight of the consistency loss; `<= 0` disables the
        unsupervised branch
    :param conf_thresh: if `> 0`, teacher confidence threshold used to mask the
        consistency loss
    :param conf_per_pixel: if false, the confidence mask is averaged to a scalar
    :param rampup: if `> 0`, the consistency loss is modulated by
        `network_architectures.sigmoid_rampup(epoch, rampup)`
    :param unsup_batch_ratio: unsupervised batches per supervised batch
    :param num_epochs: number of epochs
    :param iters_per_epoch: iterations per epoch; `-1` means
        `len(unsup_ndx) // batch_size`
    :param batch_size: batch size for all data loaders
    :param save_preds: if true, save predictions for the target validation set
        (and the test set, if any)
    :param save_model: if true, save the evaluation network with `torch.save`
    :param num_workers: data loader worker count
    :return: `None`. Returns early (after printing a message) for an unknown
        `model`, for `bin_fill_holes` on a non-binary dataset, and when a NaN
        loss is detected.
    :raises ValueError: unknown `opt_type` or `cons_loss_fn`, or `freeze_bn`
        requested for a network without `freeze_batchnorm`
    :raises NotImplementedError: `aug_scale_hung` without a `crop_size`
    """
    # NOTE: must stay the first statement so that only the arguments are captured
    settings = locals().copy()
    del settings['submit_config']

    import os
    import time
    import itertools
    import math
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from architectures import network_architectures
    import torch.utils.data
    from datapipe import datasets
    from datapipe import seg_data, seg_transforms, seg_transforms_cv
    import evaluation
    import optim_weight_ema
    import lr_schedules

    if crop_size == '':
        crop_size = None
    else:
        crop_size = [int(x.strip()) for x in crop_size.split(',')]

    torch_device = torch.device('cuda:0')

    #
    # Load data sets
    #
    ds_dict = datasets.load_dataset(dataset, n_val, val_seed, n_sup, n_unsup,
                                    split_seed, split_path)

    ds_src = ds_dict['ds_src']
    ds_tgt = ds_dict['ds_tgt']
    tgt_val_ndx = ds_dict['val_ndx_tgt']
    src_val_ndx = ds_dict['val_ndx_src'] if ds_src is not ds_tgt else None
    test_ndx = ds_dict['test_ndx_tgt']
    sup_ndx = ds_dict['sup_ndx']
    unsup_ndx = ds_dict['unsup_ndx']

    n_classes = ds_src.num_classes
    root_n_classes = math.sqrt(n_classes)

    if bin_fill_holes and n_classes != 2:
        print(
            'Binary hole filling can only be used with binary (2-class) segmentation datasets'
        )
        return

    print('Loaded data')

    # Build network
    NetClass = network_architectures.seg.get(arch)

    student_net = NetClass(ds_src.num_classes).to(torch_device)

    # Pretrained parameters are trained with a 10x smaller learning rate
    if opt_type == 'adam':
        student_optim = torch.optim.Adam([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ])
    elif opt_type == 'sgd':
        student_optim = torch.optim.SGD([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ], momentum=sgd_momentum, nesterov=sgd_nesterov,
            weight_decay=sgd_weight_decay)
    else:
        raise ValueError('Unknown opt_type {}'.format(opt_type))

    if model == 'mean_teacher':
        teacher_net = NetClass(ds_src.num_classes).to(torch_device)

        # The teacher is never trained by back-propagation
        for p in teacher_net.parameters():
            p.requires_grad = False

        teacher_optim = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, teacher_alpha)
        eval_net = teacher_net
    elif model == 'pi':
        teacher_net = student_net
        teacher_optim = None
        eval_net = student_net
    else:
        print('Unknown model type {}'.format(model))
        return

    # Network used to compute the adversarial direction
    if vat_dir_from_student:
        vat_dir_net = student_net
    else:
        vat_dir_net = teacher_net

    BLOCK_SIZE = student_net.BLOCK_SIZE
    NET_MEAN, NET_STD = seg_transforms.get_mean_std(ds_tgt, student_net)

    if freeze_bn:
        if not hasattr(student_net, 'freeze_batchnorm'):
            raise ValueError(
                'Network {} does not support batchnorm freezing'.format(arch))

    clf_crossent_loss = nn.CrossEntropyLoss(ignore_index=255)

    print('Built network')

    if iters_per_epoch == -1:
        iters_per_epoch = len(unsup_ndx) // batch_size
    total_iters = iters_per_epoch * num_epochs

    lr_epoch_scheduler, lr_iter_scheduler = lr_schedules.make_lr_schedulers(
        optimizer=student_optim, total_iters=total_iters,
        schedule_type=lr_sched, step_epochs=lr_step_epochs,
        step_gamma=lr_step_gamma, poly_power=lr_poly_power)

    # Train data pipeline: transforms
    train_transforms = []

    if crop_size is not None:
        if aug_scale_hung:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropScaleHung(
                    crop_size, (0, 0),
                    uniform_scale=not aug_scale_non_uniform))
        elif aug_max_scale != 1.0 or aug_rot_mag != 0.0:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropRotateScale(
                    crop_size, (0, 0), rot_mag=aug_rot_mag,
                    max_scale=aug_max_scale,
                    uniform_scale=not aug_scale_non_uniform,
                    constrain_rot_scale=True))
        else:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCrop(crop_size, (0, 0)))
    else:
        if aug_scale_hung:
            raise NotImplementedError('aug_scale_hung requires a crop_size')

    if aug_hflip or aug_vflip or aug_hvflip:
        train_transforms.append(
            seg_transforms_cv.SegCVTransformRandomFlip(aug_hflip, aug_vflip,
                                                       aug_hvflip))
    train_transforms.append(
        seg_transforms_cv.SegCVTransformNormalizeToTensor(NET_MEAN, NET_STD))

    # Train data pipeline: supervised and unsupervised data sets
    train_sup_ds = ds_src.dataset(
        labels=True, mask=False, xf=False, pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')
    train_unsup_ds = ds_src.dataset(
        labels=False, mask=True, xf=False, pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')

    collate_fn = seg_data.SegCollate(BLOCK_SIZE)

    # Train data pipeline: data loaders
    sup_sampler = seg_data.RepeatSampler(
        torch.utils.data.SubsetRandomSampler(sup_ndx))
    train_sup_loader = torch.utils.data.DataLoader(train_sup_ds, batch_size,
                                                   sampler=sup_sampler,
                                                   collate_fn=collate_fn,
                                                   num_workers=num_workers)
    if cons_weight > 0.0:
        unsup_sampler = seg_data.RepeatSampler(
            torch.utils.data.SubsetRandomSampler(unsup_ndx))
        train_unsup_loader = torch.utils.data.DataLoader(
            train_unsup_ds, batch_size, sampler=unsup_sampler,
            collate_fn=collate_fn, num_workers=num_workers)
    else:
        train_unsup_loader = None

    # Eval pipeline
    src_val_loader, tgt_val_loader, test_loader = datasets.eval_data_pipeline(
        ds_src, ds_tgt, src_val_ndx, tgt_val_ndx, test_ndx, batch_size,
        collate_fn, NET_MEAN, NET_STD, num_workers)

    # Report settings
    print('Settings:')
    print(', '.join([
        '{}={}'.format(key, settings[key]) for key in sorted(settings)
    ]))

    # Report dataset size
    print('Dataset:')
    print('len(sup_ndx)={}'.format(len(sup_ndx)))
    print('len(unsup_ndx)={}'.format(len(unsup_ndx)))
    if ds_src is not ds_tgt:
        # Report the source validation split size (this used to print the
        # target split size under the source label)
        print('len(src_val_ndx)={}'.format(len(src_val_ndx)))
        print('len(tgt_val_ndx)={}'.format(len(tgt_val_ndx)))
    else:
        print('len(val_ndx)={}'.format(len(tgt_val_ndx)))
    if test_ndx is not None:
        print('len(test_ndx)={}'.format(len(test_ndx)))

    if n_sup != -1:
        print('sup_ndx={}'.format(sup_ndx.tolist()))

    def normalize_eps(x):
        """Scale each sample of the 4D tensor `x` to (approximately) unit L2 norm."""
        x_flat = x.view(len(x), -1)
        mag = torch.sqrt((x_flat * x_flat).sum(dim=1))
        # Small constant guards against division by zero
        return x / (mag[:, None, None, None] + 1e-12)

    def normalized_noise(x, requires_grad=False, scale=1.0):
        """Gaussian noise with the shape of `x`, normalised per sample then multiplied by `scale`."""
        eps = torch.randn(x.shape, dtype=torch.float, device=x.device)
        eps = normalize_eps(eps) * scale
        if requires_grad:
            # Make eps a leaf tensor so that gradients can be taken w.r.t. it
            eps = eps.clone().detach().requires_grad_(True)
        return eps

    def vat_direction(x):
        """
        Compute the VAT perturbation direction vector

        :param x: input image as a `(N, C, H, W)` tensor
        :return: tuple `(direction, y_pred_logits, y_pred_prob)`; the VAT
            direction is a `(N, C, H, W)` tensor, the other two are the
            predictions of `vat_dir_net` for the unperturbed `x`
        """
        # Put the network used to get the VAT direction in eval mode and get the predicted
        # logits and probabilities for the batch of samples x
        # NOTE(review): the network is not switched back to train mode here or in the
        # batch loop; it only returns to train mode via the `.train()` calls at the
        # start of the next epoch. Confirm this is intended, particularly when
        # `vat_dir_from_student` is set.
        vat_dir_net.eval()
        with torch.no_grad():
            y_pred_logits = vat_dir_net(x).detach()
        y_pred_prob = F.softmax(y_pred_logits, dim=1)

        # Initial noise offset vector with requires_grad=True
        noise_scale = 1.0e-6 * x.shape[2] * x.shape[3] / 1000
        eps = normalized_noise(x, requires_grad=True, scale=noise_scale)

        # Predict logits and probs for sample perturbed by eps
        eps_pred_logits = vat_dir_net(x.detach() + eps)
        eps_pred_prob = F.softmax(eps_pred_logits, dim=1)

        # Choose our loss function
        if cons_loss_fn == 'var':
            delta = (eps_pred_prob - y_pred_prob)
            loss = (delta * delta).sum()
        elif cons_loss_fn == 'bce':
            loss = network_architectures.robust_binary_crossentropy(
                eps_pred_prob, y_pred_prob).sum()
        elif cons_loss_fn == 'kld':
            # `reduction='none'` replaces the deprecated `reduce=False`
            loss = F.kl_div(F.log_softmax(eps_pred_logits, dim=1),
                            y_pred_prob, reduction='none').sum()
        elif cons_loss_fn == 'logits_var':
            delta = (eps_pred_logits - y_pred_logits)
            loss = (delta * delta).sum()
        else:
            raise ValueError(
                'Unknown consistency loss function {}'.format(cons_loss_fn))

        # Differentiate the loss w.r.t. the perturbation
        eps_adv = torch.autograd.grad(outputs=loss, inputs=eps,
                                      create_graph=True, retain_graph=True,
                                      only_inputs=True)[0]

        # Normalize the adversarial perturbation
        return normalize_eps(eps_adv), y_pred_logits, y_pred_prob

    def vat_perburbation(x, m):
        """
        Compute the (detached) adversarial perturbation for the image batch `x`.

        :param x: input image as a `(N, C, H, W)` tensor
        :param m: mask batch; currently unused
        :return: tuple `(perturbation, y_pred_logits, y_pred_prob)`
        """
        eps_adv_nrm, y_pred_logits, y_pred_prob = vat_direction(x)

        if adaptive_vat_radius:
            # We view semantic segmentation as predicting the class of a pixel
            # given a patch centred on that pixel.
            # The most similar patch in terms of pixel content to a patch P
            # is a patch Q whose central pixel is an immediate neighbour
            # of the central pixel P.
            # We therefore use the image Jacobian (gradient w.r.t. x and y) to
            # get a sense of the distance between neighbouring patches
            # so we can scale the VAT radius according to the image content.

            # Delta in vertical and horizontal directions
            delta_v = x[:, :, 2:, :] - x[:, :, :-2, :]
            delta_h = x[:, :, :, 2:] - x[:, :, :, :-2]

            # delta_h and delta_v are the difference between pixels where the step size is 2, rather than 1
            # So divide by 2 to get the magnitude of the Jacobian (the `* 0.5` below)
            delta_v = delta_v.view(len(delta_v), -1)
            delta_h = delta_h.view(len(delta_h), -1)
            adv_radius = vat_radius * torch.sqrt(
                (delta_v**2).sum(dim=1) +
                (delta_h**2).sum(dim=1))[:, None, None, None] * 0.5
        else:
            # Fixed radius, scaled by the square root of the per-sample element count
            scale = math.sqrt(float(x.shape[1] * x.shape[2] * x.shape[3]))
            adv_radius = vat_radius * scale

        return (eps_adv_nrm *
                adv_radius).detach(), y_pred_logits, y_pred_prob

    # Create iterators
    # (the samplers are wrapped in RepeatSampler; the supervised iterator is
    # consumed `iters_per_epoch` batches at a time via islice)
    train_sup_iter = iter(train_sup_loader)
    train_unsup_iter = iter(
        train_unsup_loader) if train_unsup_loader is not None else None

    iter_i = 0
    print('Training...')
    for epoch_i in range(num_epochs):
        if lr_epoch_scheduler is not None:
            lr_epoch_scheduler.step(epoch_i)

        t1 = time.time()

        if rampup > 0:
            ramp_val = network_architectures.sigmoid_rampup(epoch_i, rampup)
        else:
            ramp_val = 1.0

        student_net.train()
        if teacher_net is not student_net:
            teacher_net.train()

        if freeze_bn:
            student_net.freeze_batchnorm()
            if teacher_net is not student_net:
                teacher_net.freeze_batchnorm()

        sup_loss_acc = 0.0
        consistency_loss_acc = 0.0
        conf_rate_acc = 0.0
        n_sup_batches = 0
        n_unsup_batches = 0

        src_val_iter = iter(
            src_val_loader) if src_val_loader is not None else None
        tgt_val_iter = iter(
            tgt_val_loader) if tgt_val_loader is not None else None

        for sup_batch in itertools.islice(train_sup_iter, iters_per_epoch):
            if lr_iter_scheduler is not None:
                lr_iter_scheduler.step(iter_i)
            student_optim.zero_grad()

            #
            # Supervised branch
            #

            batch_x = sup_batch['image'].to(torch_device)
            batch_y = sup_batch['labels'].to(torch_device)

            logits_sup = student_net(batch_x)
            sup_loss = clf_crossent_loss(logits_sup, batch_y[:, 0, :, :])
            sup_loss.backward()

            if cons_weight > 0.0:
                for _ in range(unsup_batch_ratio):
                    #
                    # Unsupervised branch
                    #

                    unsup_batch = next(train_unsup_iter)

                    # Input images to torch tensor
                    batch_ux = unsup_batch['image'].to(torch_device)
                    batch_um = unsup_batch['mask'].to(torch_device)

                    # batch_um is a mask that is 1 for valid pixels, 0 for invalid pixels.
                    # It is used later on to scale the consistency loss, so that consistency loss is
                    # only computed for valid pixels.
                    # Explanation:
                    # When using geometric augmentations such as rotations, some pixels in the training
                    # crop may come from outside the bounds of the input image. These pixels will have a value
                    # of 0 in these masks. Similarly, when using scaled crops, the size of the crop
                    # from the input image that must be scaled to the size of the training crop may be
                    # larger than one/both of the input image dimensions. Pixels in the training crop
                    # that arise from outside the input image bounds will once again be given a value
                    # of 0 in these masks.

                    # Compute VAT perturbation. The predictions returned alongside it come
                    # from `vat_dir_net`; the teacher predictions are computed explicitly below.
                    x_perturb, _, _ = vat_perburbation(batch_ux, batch_um)

                    # Perturb image
                    batch_ux_adv = batch_ux + x_perturb

                    # Get teacher predictions for original image
                    with torch.no_grad():
                        logits_cons_tea = teacher_net(batch_ux).detach()
                    # Get student prediction for the adversarially perturbed image
                    logits_cons_stu = student_net(batch_ux_adv)

                    # Logits -> probs
                    prob_cons_tea = F.softmax(logits_cons_tea, dim=1)
                    prob_cons_stu = F.softmax(logits_cons_stu, dim=1)

                    loss_mask = batch_um

                    # Confidence thresholding
                    if conf_thresh > 0.0:
                        # Compute confidence of teacher predictions
                        conf_tea = prob_cons_tea.max(dim=1)[0]
                        # Compute confidence mask
                        conf_mask = (conf_tea >=
                                     conf_thresh).float()[:, None, :, :]
                        # Record rate for reporting
                        conf_rate_acc += float(conf_mask.mean())
                        # Average confidence mask if requested
                        if not conf_per_pixel:
                            conf_mask = conf_mask.mean()

                        loss_mask = loss_mask * conf_mask
                    elif rampup > 0:
                        # Without thresholding, the reported 'conf rate' is the ramp-up value
                        conf_rate_acc += ramp_val

                    # Compute per-pixel consistency loss
                    # Note that the way we aggregate the loss across the class/channel dimension (1)
                    # depends on the loss function used. Generally, summing over the class dimension
                    # keeps the magnitude of the gradient of the loss w.r.t. the logits
                    # nearly constant w.r.t. the number of classes. When using logit-variance,
                    # dividing by `sqrt(num_classes)` helps.
                    if cons_loss_fn == 'var':
                        delta_prob = prob_cons_stu - prob_cons_tea
                        consistency_loss = delta_prob * delta_prob
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'logits_var':
                        delta_logits = logits_cons_stu - logits_cons_tea
                        consistency_loss = delta_logits * delta_logits
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'bce':
                        consistency_loss = network_architectures.robust_binary_crossentropy(
                            prob_cons_stu, prob_cons_tea)
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'kld':
                        # `reduction='none'` replaces the deprecated `reduce=False`
                        consistency_loss = F.kl_div(
                            F.log_softmax(logits_cons_stu, dim=1),
                            prob_cons_tea, reduction='none')
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    else:
                        raise ValueError(
                            'Unknown consistency loss function {}'.format(
                                cons_loss_fn))

                    # Apply consistency loss mask and take the mean over pixels and images
                    consistency_loss = (consistency_loss * loss_mask).mean()

                    # Modulate with rampup if desired
                    if rampup > 0:
                        consistency_loss = consistency_loss * ramp_val

                    # Weight the consistency loss and back-prop
                    unsup_loss = consistency_loss * cons_weight
                    unsup_loss.backward()

                    consistency_loss_val = float(consistency_loss.detach())
                    consistency_loss_acc += consistency_loss_val

                    if np.isnan(consistency_loss_val):
                        print(
                            'NaN detected in consistency loss; bailing out...')
                        return

                    n_unsup_batches += 1

            # Gradients from the supervised and all unsupervised batches are
            # accumulated before this single optimiser step
            student_optim.step()
            if teacher_optim is not None:
                teacher_optim.step()

            sup_loss_val = float(sup_loss.detach())
            sup_loss_acc += sup_loss_val

            if np.isnan(sup_loss_val):
                print('NaN detected in supervised loss; bailing out...')
                return

            n_sup_batches += 1
            iter_i += 1

        sup_loss_acc /= n_sup_batches
        if n_unsup_batches > 0:
            consistency_loss_acc /= n_unsup_batches
            conf_rate_acc /= n_unsup_batches

        eval_net.eval()

        if src_val_iter is not None:
            src_iou_eval = evaluation.EvaluatorIoU(ds_src.num_classes,
                                                   bin_fill_holes)
            with torch.no_grad():
                for batch in src_val_iter:
                    batch_x = batch['image'].to(torch_device)
                    batch_y = batch['labels'].numpy()

                    logits = eval_net(batch_x)
                    pred_y = torch.argmax(logits,
                                          dim=1).detach().cpu().numpy()

                    for sample_i in range(len(batch_y)):
                        src_iou_eval.sample(batch_y[sample_i, 0],
                                            pred_y[sample_i],
                                            ignore_value=255)

            src_iou = src_iou_eval.score()
            src_miou = src_iou.mean()
        else:
            src_iou_eval = src_iou = src_miou = None

        tgt_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                               bin_fill_holes)
        with torch.no_grad():
            for batch in tgt_val_iter:
                batch_x = batch['image'].to(torch_device)
                batch_y = batch['labels'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i in range(len(batch_y)):
                    tgt_iou_eval.sample(batch_y[sample_i, 0],
                                        pred_y[sample_i], ignore_value=255)

        tgt_iou = tgt_iou_eval.score()
        tgt_miou = tgt_iou.mean()

        t2 = time.time()

        if ds_src is not ds_tgt:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, '
                'SRC VAL mIoU={:.3%}, TGT VAL mIoU={:.3%}'.format(
                    epoch_i + 1, t2 - t1, sup_loss_acc, consistency_loss_acc,
                    conf_rate_acc, src_miou, tgt_miou))
            print('-- SRC {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in src_iou])))
            print('-- TGT {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))
        else:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, VAL mIoU={:.3%}'
                .format(epoch_i + 1, t2 - t1, sup_loss_acc,
                        consistency_loss_acc, conf_rate_acc, tgt_miou))
            print('-- {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))

    if save_model:
        model_path = os.path.join(submit_config.run_dir, "model.pth")
        torch.save(eval_net, model_path)

    if save_preds:
        out_dir = os.path.join(submit_config.run_dir, 'preds')
        os.makedirs(out_dir, exist_ok=True)
        with torch.no_grad():
            for batch in tgt_val_loader:
                batch_x = batch['image'].to(torch_device)
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    ds_tgt.save_prediction_by_index(
                        out_dir, pred_y[sample_i].astype(np.uint32),
                        sample_ndx)
    else:
        out_dir = None

    if test_loader is not None:
        test_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                                bin_fill_holes)
        with torch.no_grad():
            for batch in test_loader:
                batch_x = batch['image'].to(torch_device)
                # Ground truth for THIS batch. `batch_y` was previously never assigned
                # in this loop, so test IoU was scored against the labels left over
                # from the last validation batch.
                # NOTE(review): assumes test batches carry a 'labels' entry like the
                # validation batches do -- confirm against `datasets.eval_data_pipeline`
                batch_y = batch['labels'].numpy()
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    if save_preds:
                        ds_tgt.save_prediction_by_index(
                            out_dir, pred_y[sample_i].astype(np.uint32),
                            sample_ndx)
                    test_iou_eval.sample(batch_y[sample_i, 0],
                                         pred_y[sample_i], ignore_value=255)

        test_iou = test_iou_eval.score()
        test_miou = test_iou.mean()

        print('FINAL TEST: mIoU={:.3%}'.format(test_miou))
        print('-- TEST {}'.format(', '.join(
            ['{:.3%}'.format(x) for x in test_iou])))
def train_seg_semisup_aug_mt(
        submit_config: job_helper.SubmitConfig, dataset, model, arch, freeze_bn,
        opt_type, sgd_momentum, sgd_nesterov, sgd_weight_decay, learning_rate,
        lr_sched, lr_step_epochs, lr_step_gamma, lr_poly_power, teacher_alpha,
        bin_fill_holes, crop_size, aug_offset_range, aug_hflip, aug_vflip,
        aug_hvflip, aug_scale_hung, aug_max_scale, aug_scale_non_uniform,
        aug_rot_mag, aug_free_scale_rot, cons_loss_fn, cons_weight,
        conf_thresh, conf_per_pixel, rampup, unsup_batch_ratio, num_epochs,
        iters_per_epoch, batch_size, n_sup, n_unsup, n_val, split_seed,
        split_path, val_seed, save_preds, save_model, num_workers):
    """
    Train a semantic segmentation network with supervised cross-entropy plus an
    augmentation-driven consistency loss between a student and a teacher
    network.

    Each unsupervised batch provides two augmented samples (`sample0`,
    `sample1`) and a transformation `xf0_to_1`. The teacher predicts on
    `sample0`, its predictions and mask are resampled into the frame of
    `sample1` with `F.affine_grid` / `F.grid_sample`, and the masked difference
    to the student's prediction on `sample1` is penalised.
    After every epoch the evaluation network is scored (per-class IoU and mIoU)
    on the validation set(s); after training the model and/or predictions are
    optionally saved and the test set is scored.

    :param submit_config: job configuration; only `run_dir` is read here (output
        directory for `model.pth` and the `preds` folder)
    :param dataset, n_val, val_seed, n_sup, n_unsup, split_seed, split_path:
        forwarded unchanged to `datasets.load_dataset`
    :param model: `'mean_teacher'` (separate teacher updated through
        `EMAWeightOptimizer`; the teacher is evaluated) or `'pi'` (teacher is the
        student itself)
    :param arch: key looked up in `network_architectures.seg`
    :param freeze_bn: if true, `freeze_batchnorm()` is called on the networks at
        the start of each epoch
    :param opt_type: `'adam'` or `'sgd'`
    :param sgd_momentum, sgd_nesterov, sgd_weight_decay: SGD-only options
    :param learning_rate: learning rate for new parameters; pretrained
        parameters use `0.1 * learning_rate`
    :param lr_sched, lr_step_epochs, lr_step_gamma, lr_poly_power: forwarded to
        `lr_schedules.make_lr_schedulers`
    :param teacher_alpha: forwarded to `EMAWeightOptimizer`
    :param bin_fill_holes: forwarded to `EvaluatorIoU`; only permitted for
        2-class datasets
    :param crop_size: comma separated integers as a string, or `''` for no crop
    :param aug_offset_range: passed as `(aug_offset_range, aug_offset_range)` to
        the random crop transforms
    :param aug_hflip, aug_vflip, aug_hvflip, aug_scale_hung, aug_max_scale,
        aug_scale_non_uniform, aug_rot_mag: augmentation settings used to build
        the training transform list
    :param aug_free_scale_rot: negated and passed as `constrain_rot_scale`
    :param cons_loss_fn: `'var'`, `'logits_var'`, `'logits_smoothl1'`, `'bce'`
        or `'kld'`
    :param cons_weight: weight of the consistency loss; `<= 0` disables the
        unsupervised branch
    :param conf_thresh: if `> 0`, teacher confidence threshold used to mask the
        consistency loss
    :param conf_per_pixel: if false, the confidence mask is averaged to a scalar
    :param rampup: if `> 0`, the consistency loss is modulated by
        `network_architectures.sigmoid_rampup(epoch, rampup)`
    :param unsup_batch_ratio: unsupervised batches per supervised batch
    :param num_epochs: number of epochs
    :param iters_per_epoch: iterations per epoch; `-1` means
        `len(unsup_ndx) // batch_size`
    :param batch_size: batch size for all data loaders
    :param save_preds: if true, save predictions for the target validation set
        (and the test set, if any)
    :param save_model: if true, save the evaluation network with `torch.save`
    :param num_workers: data loader worker count
    :return: `None`. Returns early (after printing a message) for an unknown
        `model` and for `bin_fill_holes` on a non-binary dataset.
    :raises ValueError: unknown `opt_type` or `cons_loss_fn`, or `freeze_bn`
        requested for a network without `freeze_batchnorm`
    :raises NotImplementedError: `aug_scale_hung` without a `crop_size`
    """
    # NOTE: must stay the first statement so that only the arguments are captured
    settings = locals().copy()
    del settings['submit_config']

    import os
    import math
    import time
    import itertools
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from architectures import network_architectures
    import torch.utils.data
    from datapipe import datasets
    from datapipe import seg_data, seg_transforms, seg_transforms_cv
    import evaluation
    import optim_weight_ema
    import lr_schedules

    from datapipe import torch_utils
    # Keyword arguments forwarded to F.affine_grid / F.grid_sample below
    affine_align_corners_kw = torch_utils.affine_align_corners_kw(True)

    if crop_size == '':
        crop_size = None
    else:
        crop_size = [int(x.strip()) for x in crop_size.split(',')]

    torch_device = torch.device('cuda:0')

    #
    # Load data sets
    #
    ds_dict = datasets.load_dataset(dataset, n_val, val_seed, n_sup, n_unsup,
                                    split_seed, split_path)

    ds_src = ds_dict['ds_src']
    ds_tgt = ds_dict['ds_tgt']
    tgt_val_ndx = ds_dict['val_ndx_tgt']
    src_val_ndx = ds_dict['val_ndx_src'] if ds_src is not ds_tgt else None
    test_ndx = ds_dict['test_ndx_tgt']
    sup_ndx = ds_dict['sup_ndx']
    unsup_ndx = ds_dict['unsup_ndx']

    n_classes = ds_src.num_classes
    root_n_classes = math.sqrt(n_classes)

    if bin_fill_holes and n_classes != 2:
        print(
            'Binary hole filling can only be used with binary (2-class) segmentation datasets'
        )
        return

    print('Loaded data')

    # Build network
    NetClass = network_architectures.seg.get(arch)

    student_net = NetClass(ds_src.num_classes).to(torch_device)

    # Pretrained parameters are trained with a 10x smaller learning rate
    if opt_type == 'adam':
        student_optim = torch.optim.Adam([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ])
    elif opt_type == 'sgd':
        student_optim = torch.optim.SGD([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ], momentum=sgd_momentum, nesterov=sgd_nesterov,
            weight_decay=sgd_weight_decay)
    else:
        raise ValueError('Unknown opt_type {}'.format(opt_type))

    if model == 'mean_teacher':
        teacher_net = NetClass(ds_src.num_classes).to(torch_device)

        # The teacher is never trained by back-propagation
        for p in teacher_net.parameters():
            p.requires_grad = False

        teacher_optim = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, teacher_alpha)
        eval_net = teacher_net
    elif model == 'pi':
        teacher_net = student_net
        teacher_optim = None
        eval_net = student_net
    else:
        print('Unknown model type {}'.format(model))
        return

    BLOCK_SIZE = student_net.BLOCK_SIZE
    NET_MEAN, NET_STD = seg_transforms.get_mean_std(ds_tgt, student_net)

    if freeze_bn:
        if not hasattr(student_net, 'freeze_batchnorm'):
            raise ValueError(
                'Network {} does not support batchnorm freezing'.format(arch))

    clf_crossent_loss = nn.CrossEntropyLoss(ignore_index=255)

    print('Built network')

    if iters_per_epoch == -1:
        iters_per_epoch = len(unsup_ndx) // batch_size
    total_iters = iters_per_epoch * num_epochs

    lr_epoch_scheduler, lr_iter_scheduler = lr_schedules.make_lr_schedulers(
        optimizer=student_optim, total_iters=total_iters,
        schedule_type=lr_sched, step_epochs=lr_step_epochs,
        step_gamma=lr_step_gamma, poly_power=lr_poly_power)

    # Train data pipeline: transforms
    train_transforms = []

    if crop_size is not None:
        if aug_scale_hung:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropScaleHung(
                    crop_size, (aug_offset_range, aug_offset_range),
                    uniform_scale=not aug_scale_non_uniform))
        elif aug_max_scale != 1.0 or aug_rot_mag != 0.0:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropRotateScale(
                    crop_size, (aug_offset_range, aug_offset_range),
                    rot_mag=aug_rot_mag, max_scale=aug_max_scale,
                    uniform_scale=not aug_scale_non_uniform,
                    constrain_rot_scale=not aug_free_scale_rot))
        else:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCrop(
                    crop_size, (aug_offset_range, aug_offset_range)))
    else:
        if aug_scale_hung:
            raise NotImplementedError('aug_scale_hung requires a crop_size')

    if aug_hflip or aug_vflip or aug_hvflip:
        train_transforms.append(
            seg_transforms_cv.SegCVTransformRandomFlip(aug_hflip, aug_vflip,
                                                       aug_hvflip))
    train_transforms.append(
        seg_transforms_cv.SegCVTransformNormalizeToTensor(NET_MEAN, NET_STD))

    # Train data pipeline: supervised and unsupervised data sets
    train_sup_ds = ds_src.dataset(
        labels=True, mask=False, xf=False, pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')
    # Unsupervised samples are requested as pairs, together with their transformations
    train_unsup_ds = ds_src.dataset(
        labels=False, mask=True, xf=True, pair=True,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')

    collate_fn = seg_data.SegCollate(BLOCK_SIZE)

    # Train data pipeline: data loaders
    sup_sampler = seg_data.RepeatSampler(
        torch.utils.data.SubsetRandomSampler(sup_ndx))
    train_sup_loader = torch.utils.data.DataLoader(train_sup_ds, batch_size,
                                                   sampler=sup_sampler,
                                                   collate_fn=collate_fn,
                                                   num_workers=num_workers)
    if cons_weight > 0.0:
        unsup_sampler = seg_data.RepeatSampler(
            torch.utils.data.SubsetRandomSampler(unsup_ndx))
        train_unsup_loader = torch.utils.data.DataLoader(
            train_unsup_ds, batch_size, sampler=unsup_sampler,
            collate_fn=collate_fn, num_workers=num_workers)
    else:
        train_unsup_loader = None

    # Eval pipeline
    src_val_loader, tgt_val_loader, test_loader = datasets.eval_data_pipeline(
        ds_src, ds_tgt, src_val_ndx, tgt_val_ndx, test_ndx, batch_size,
        collate_fn, NET_MEAN, NET_STD, num_workers)

    # Report settings
    print('Settings:')
    print(', '.join([
        '{}={}'.format(key, settings[key]) for key in sorted(settings)
    ]))

    # Report dataset size
    print('Dataset:')
    print('len(sup_ndx)={}'.format(len(sup_ndx)))
    print('len(unsup_ndx)={}'.format(len(unsup_ndx)))
    if ds_src is not ds_tgt:
        # Report the source validation split size (this used to print the
        # target split size under the source label)
        print('len(src_val_ndx)={}'.format(len(src_val_ndx)))
        print('len(tgt_val_ndx)={}'.format(len(tgt_val_ndx)))
    else:
        print('len(val_ndx)={}'.format(len(tgt_val_ndx)))
    if test_ndx is not None:
        print('len(test_ndx)={}'.format(len(test_ndx)))

    if n_sup != -1:
        print('sup_ndx={}'.format(sup_ndx.tolist()))

    # Create iterators
    # (the samplers are wrapped in RepeatSampler; the supervised iterator is
    # consumed `iters_per_epoch` batches at a time via islice)
    train_sup_iter = iter(train_sup_loader)
    train_unsup_iter = iter(
        train_unsup_loader) if train_unsup_loader is not None else None

    iter_i = 0
    print('Training...')
    for epoch_i in range(num_epochs):
        if lr_epoch_scheduler is not None:
            lr_epoch_scheduler.step(epoch_i)

        t1 = time.time()

        if rampup > 0:
            ramp_val = network_architectures.sigmoid_rampup(epoch_i, rampup)
        else:
            ramp_val = 1.0

        student_net.train()
        if teacher_net is not student_net:
            teacher_net.train()

        if freeze_bn:
            student_net.freeze_batchnorm()
            if teacher_net is not student_net:
                teacher_net.freeze_batchnorm()

        sup_loss_acc = 0.0
        consistency_loss_acc = 0.0
        conf_rate_acc = 0.0
        n_sup_batches = 0
        n_unsup_batches = 0

        src_val_iter = iter(
            src_val_loader) if src_val_loader is not None else None
        tgt_val_iter = iter(
            tgt_val_loader) if tgt_val_loader is not None else None

        for sup_batch in itertools.islice(train_sup_iter, iters_per_epoch):
            if lr_iter_scheduler is not None:
                lr_iter_scheduler.step(iter_i)
            student_optim.zero_grad()

            #
            # Supervised branch
            #

            batch_x = sup_batch['image'].to(torch_device)
            batch_y = sup_batch['labels'].to(torch_device)

            logits_sup = student_net(batch_x)
            sup_loss = clf_crossent_loss(logits_sup, batch_y[:, 0, :, :])
            sup_loss.backward()

            if cons_weight > 0.0:
                for _ in range(unsup_batch_ratio):
                    #
                    # Unsupervised branch
                    #

                    # Batch consists of two augmented samples (image + mask each)
                    # and the transformation `xf0_to_1` between them
                    unsup_batch = next(train_unsup_iter)

                    # Input images to torch tensor
                    batch_ux0 = unsup_batch['sample0']['image'].to(
                        torch_device)
                    batch_um0 = unsup_batch['sample0']['mask'].to(
                        torch_device)
                    batch_ux1 = unsup_batch['sample1']['image'].to(
                        torch_device)
                    batch_um1 = unsup_batch['sample1']['mask'].to(
                        torch_device)
                    batch_ufx0_to_1 = unsup_batch['xf0_to_1'].to(torch_device)

                    # Get teacher predictions for image0
                    with torch.no_grad():
                        logits_cons_tea = teacher_net(batch_ux0).detach()
                    # Get student prediction for image1
                    logits_cons_stu = student_net(batch_ux1)

                    # Transformation from teacher to student space
                    grid_tea_to_stu = F.affine_grid(batch_ufx0_to_1,
                                                    batch_ux0.shape,
                                                    **affine_align_corners_kw)
                    # Transform teacher predicted logits to student space
                    logits_cons_tea_in_stu = F.grid_sample(
                        logits_cons_tea, grid_tea_to_stu,
                        **affine_align_corners_kw)
                    # Transform mask from teacher to student space and multiply by student space mask
                    mask_tea_in_stu = F.grid_sample(
                        batch_um0, grid_tea_to_stu,
                        **affine_align_corners_kw) * batch_um1

                    # Logits -> probs
                    prob_cons_tea = F.softmax(logits_cons_tea, dim=1)
                    prob_cons_stu = F.softmax(logits_cons_stu, dim=1)
                    # Transform teacher predicted probabilities to student space
                    prob_cons_tea_in_stu = F.grid_sample(
                        prob_cons_tea, grid_tea_to_stu,
                        **affine_align_corners_kw)

                    loss_mask = mask_tea_in_stu

                    # Confidence thresholding
                    if conf_thresh > 0.0:
                        # Compute confidence of teacher predictions
                        conf_tea = prob_cons_tea_in_stu.max(dim=1)[0]
                        # Compute confidence mask
                        conf_mask = (conf_tea >=
                                     conf_thresh).float()[:, None, :, :]
                        # Record rate for reporting
                        conf_rate_acc += float(conf_mask.mean())
                        # Average confidence mask if requested
                        if not conf_per_pixel:
                            conf_mask = conf_mask.mean()

                        loss_mask = loss_mask * conf_mask
                    elif rampup > 0:
                        # Without thresholding, the reported 'conf rate' is the ramp-up value
                        conf_rate_acc += ramp_val

                    # Compute per-pixel consistency loss
                    # Note that the way we aggregate the loss across the class/channel dimension (1)
                    # depends on the loss function used. Generally, summing over the class dimension
                    # keeps the magnitude of the gradient of the loss w.r.t. the logits
                    # nearly constant w.r.t. the number of classes. When using logit-variance,
                    # dividing by `sqrt(num_classes)` helps.
                    if cons_loss_fn == 'var':
                        delta_prob = prob_cons_stu - prob_cons_tea_in_stu
                        consistency_loss = delta_prob * delta_prob
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'logits_var':
                        # Squared difference of the logits. A stray
                        # `consistency_loss = delta_prob * delta_prob` used to follow here;
                        # `delta_prob` is only assigned in the 'var' branch, so it raised
                        # NameError for this loss function.
                        delta_logits = logits_cons_stu - logits_cons_tea_in_stu
                        consistency_loss = delta_logits * delta_logits
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'logits_smoothl1':
                        # `reduction='none'` replaces the deprecated `reduce=False`
                        consistency_loss = F.smooth_l1_loss(
                            logits_cons_stu, logits_cons_tea_in_stu,
                            reduction='none')
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'bce':
                        consistency_loss = network_architectures.robust_binary_crossentropy(
                            prob_cons_stu, prob_cons_tea_in_stu)
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'kld':
                        # `reduction='none'` replaces the deprecated `reduce=False`
                        consistency_loss = F.kl_div(
                            F.log_softmax(logits_cons_stu, dim=1),
                            prob_cons_tea_in_stu, reduction='none')
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    else:
                        raise ValueError(
                            'Unknown consistency loss function {}'.format(
                                cons_loss_fn))

                    # Apply consistency loss mask and take the mean over pixels and images
                    consistency_loss = (consistency_loss * loss_mask).mean()

                    # Modulate with rampup if desired
                    if rampup > 0:
                        consistency_loss = consistency_loss * ramp_val

                    # Weight the consistency loss and back-prop
                    unsup_loss = consistency_loss * cons_weight
                    unsup_loss.backward()

                    consistency_loss_acc += float(consistency_loss.detach())

                    n_unsup_batches += 1

            # Gradients from the supervised and all unsupervised batches are
            # accumulated before this single optimiser step
            student_optim.step()
            if teacher_optim is not None:
                teacher_optim.step()

            sup_loss_acc += float(sup_loss.detach())
            n_sup_batches += 1
            iter_i += 1

        sup_loss_acc /= n_sup_batches
        if n_unsup_batches > 0:
            consistency_loss_acc /= n_unsup_batches
            conf_rate_acc /= n_unsup_batches

        eval_net.eval()

        if src_val_iter is not None:
            src_iou_eval = evaluation.EvaluatorIoU(ds_src.num_classes,
                                                   bin_fill_holes)
            with torch.no_grad():
                for batch in src_val_iter:
                    batch_x = batch['image'].to(torch_device)
                    batch_y = batch['labels'].numpy()

                    logits = eval_net(batch_x)
                    pred_y = torch.argmax(logits,
                                          dim=1).detach().cpu().numpy()

                    for sample_i in range(len(batch_y)):
                        src_iou_eval.sample(batch_y[sample_i, 0],
                                            pred_y[sample_i],
                                            ignore_value=255)

            src_iou = src_iou_eval.score()
            src_miou = src_iou.mean()
        else:
            src_iou_eval = src_iou = src_miou = None

        tgt_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                               bin_fill_holes)
        with torch.no_grad():
            for batch in tgt_val_iter:
                batch_x = batch['image'].to(torch_device)
                batch_y = batch['labels'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i in range(len(batch_y)):
                    tgt_iou_eval.sample(batch_y[sample_i, 0],
                                        pred_y[sample_i], ignore_value=255)

        tgt_iou = tgt_iou_eval.score()
        tgt_miou = tgt_iou.mean()

        t2 = time.time()

        if ds_src is not ds_tgt:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, '
                'SRC VAL mIoU={:.3%}, TGT VAL mIoU={:.3%}'.format(
                    epoch_i + 1, t2 - t1, sup_loss_acc, consistency_loss_acc,
                    conf_rate_acc, src_miou, tgt_miou))
            print('-- SRC {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in src_iou])))
            print('-- TGT {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))
        else:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, VAL mIoU={:.3%}'
                .format(epoch_i + 1, t2 - t1, sup_loss_acc,
                        consistency_loss_acc, conf_rate_acc, tgt_miou))
            print('-- {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))

    if save_model:
        model_path = os.path.join(submit_config.run_dir, "model.pth")
        torch.save(eval_net, model_path)

    if save_preds:
        out_dir = os.path.join(submit_config.run_dir, 'preds')
        os.makedirs(out_dir, exist_ok=True)
        with torch.no_grad():
            for batch in tgt_val_loader:
                batch_x = batch['image'].to(torch_device)
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    ds_tgt.save_prediction_by_index(
                        out_dir, pred_y[sample_i].astype(np.uint32),
                        sample_ndx)
    else:
        out_dir = None

    if test_loader is not None:
        test_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                                bin_fill_holes)
        with torch.no_grad():
            for batch in test_loader:
                batch_x = batch['image'].to(torch_device)
                # Ground truth for THIS batch. `batch_y` was previously never assigned
                # in this loop, so test IoU was scored against the labels left over
                # from the last validation batch.
                # NOTE(review): assumes test batches carry a 'labels' entry like the
                # validation batches do -- confirm against `datasets.eval_data_pipeline`
                batch_y = batch['labels'].numpy()
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    if save_preds:
                        ds_tgt.save_prediction_by_index(
                            out_dir, pred_y[sample_i].astype(np.uint32),
                            sample_ndx)
                    test_iou_eval.sample(batch_y[sample_i, 0],
                                         pred_y[sample_i], ignore_value=255)

        test_iou = test_iou_eval.score()
        test_miou = test_iou.mean()

        print('FINAL TEST: mIoU={:.3%}'.format(test_miou))
        print('-- TEST {}'.format(', '.join(
            ['{:.3%}'.format(x) for x in test_iou])))
def train_seg_semisup_mask_mt(
        submit_config: job_helper.SubmitConfig, dataset, model, arch, freeze_bn,
        opt_type, sgd_momentum, sgd_nesterov, sgd_weight_decay, learning_rate,
        lr_sched, lr_step_epochs, lr_step_gamma, lr_poly_power, teacher_alpha,
        bin_fill_holes, crop_size, aug_hflip, aug_vflip, aug_hvflip,
        aug_scale_hung, aug_max_scale, aug_scale_non_uniform, aug_rot_mag,
        mask_mode, mask_prop_range, boxmask_n_boxes,
        boxmask_fixed_aspect_ratio, boxmask_by_size, boxmask_outside_bounds,
        boxmask_no_invert, cons_loss_fn, cons_weight, conf_thresh,
        conf_per_pixel, rampup, unsup_batch_ratio, num_epochs, iters_per_epoch,
        batch_size, n_sup, n_unsup, n_val, split_seed, split_path, val_seed,
        save_preds, save_model, num_workers):
    """Train a segmentation network with mask-based consistency regularization.

    A student network is trained with a supervised cross-entropy loss on
    labelled samples. When ``cons_weight > 0`` a consistency loss is added
    between the student's prediction on a masked image and the teacher's
    prediction on the un-masked image(s). Two masking modes are supported:

    * ``mask_mode='zero'``: masked regions of one unlabelled image are set to
      zero before being given to the student.
    * ``mask_mode='mix'``: two unlabelled images are blended with the mask and
      the teacher predictions are blended with the same mask.

    After each epoch the evaluation network is scored (IoU) on the validation
    split(s) and a summary line is printed. After the last epoch the network
    and its validation/test predictions are optionally saved and the test
    split, if there is one, is scored.

    Args:
        submit_config: Job configuration; only ``run_dir`` is read (output
            directory for the saved model and predictions).
        dataset, n_val, val_seed, n_sup, n_unsup, split_seed, split_path:
            Forwarded to ``datasets.load_dataset`` to select the data and the
            supervised / unsupervised / validation splits. ``n_sup != -1``
            additionally causes the supervised indices to be printed.
        model: ``'mean_teacher'`` (separate teacher network updated through
            ``optim_weight_ema.EMAWeightOptimizer``; the teacher is evaluated)
            or ``'pi'`` (the student acts as its own teacher). Any other value
            prints a message and returns.
        arch: Key looked up in ``network_architectures.seg``.
        freeze_bn: If true, ``freeze_batchnorm()`` is called on the networks at
            the start of every epoch.
        opt_type: ``'adam'`` or ``'sgd'``.
        sgd_momentum, sgd_nesterov, sgd_weight_decay: SGD options.
        learning_rate: Learning rate for new parameters; pre-trained
            parameters use one tenth of it.
        lr_sched, lr_step_epochs, lr_step_gamma, lr_poly_power: Forwarded to
            ``lr_schedules.make_lr_schedulers``.
        teacher_alpha: Passed to ``EMAWeightOptimizer``.
        bin_fill_holes: Passed to ``evaluation.EvaluatorIoU``; only permitted
            for 2-class datasets.
        crop_size: Comma-separated integers as a string, or ``''`` to disable
            cropping.
        aug_hflip, aug_vflip, aug_hvflip, aug_scale_hung, aug_max_scale,
        aug_scale_non_uniform, aug_rot_mag: Training augmentation options.
        mask_mode: ``'zero'`` or ``'mix'``.
        mask_prop_range: String, either a single float or ``'lo:hi'``.
        boxmask_n_boxes, boxmask_fixed_aspect_ratio, boxmask_by_size,
        boxmask_outside_bounds, boxmask_no_invert: Options for
            ``mask_gen.BoxMaskGenerator``.
        cons_loss_fn: One of ``'var'``, ``'logits_var'``, ``'logits_smoothl1'``,
            ``'bce'`` or ``'kld'``.
        cons_weight: Weight of the consistency loss; ``<= 0`` disables the
            unsupervised branch.
        conf_thresh: If ``> 0``, teacher predictions whose maximum class
            probability is below this value are masked out of the loss.
        conf_per_pixel: If false, the confidence mask is averaged to a single
            scalar before being applied.
        rampup: If ``> 0``, the consistency loss is scaled by
            ``network_architectures.sigmoid_rampup(epoch, rampup)``.
        unsup_batch_ratio: Number of unsupervised batches per supervised one.
        num_epochs: Number of training epochs.
        iters_per_epoch: Supervised iterations per epoch; ``-1`` selects
            ``len(unsup_ndx) // batch_size``.
        batch_size: Batch size for all data loaders.
        save_preds: Save predictions for the target validation (and test)
            samples under ``<run_dir>/preds``.
        save_model: Save the evaluation network to ``<run_dir>/model.pth``.
        num_workers: Number of data loader worker processes.

    Returns:
        None. Training is abandoned early (also returning ``None``) when the
        supervised loss becomes NaN, when ``model`` is unknown, or when
        ``bin_fill_holes`` is requested for a non-binary dataset.

    Raises:
        ValueError: Unknown ``mask_mode``, ``opt_type`` or ``cons_loss_fn``,
            or ``freeze_bn`` requested for a network without
            ``freeze_batchnorm``.
        NotImplementedError: ``aug_scale_hung`` requested without a
            ``crop_size``.
    """
    # Must stay the first statement: snapshot of the arguments for reporting.
    settings = locals().copy()
    del settings['submit_config']

    # Parse the mask proportion range: 'lo:hi' -> tuple, otherwise one float
    if ':' in mask_prop_range:
        a, b = mask_prop_range.split(':')
        mask_prop_range = (float(a.strip()), float(b.strip()))
        del a, b
    else:
        mask_prop_range = float(mask_prop_range)

    if mask_mode == 'zero':
        mask_mix = False
    elif mask_mode == 'mix':
        mask_mix = True
    else:
        raise ValueError('Unknown mask_mode {}'.format(mask_mode))
    del mask_mode

    import os
    import math
    import time
    import itertools
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from architectures import network_architectures
    import torch.utils.data
    from datapipe import datasets
    from datapipe import seg_data, seg_transforms, seg_transforms_cv
    import evaluation
    import optim_weight_ema
    import mask_gen
    import lr_schedules

    if crop_size == '':
        crop_size = None
    else:
        crop_size = [int(x.strip()) for x in crop_size.split(',')]

    torch_device = torch.device('cuda:0')

    #
    # Load data sets
    #
    ds_dict = datasets.load_dataset(dataset, n_val, val_seed, n_sup, n_unsup,
                                    split_seed, split_path)

    ds_src = ds_dict['ds_src']
    ds_tgt = ds_dict['ds_tgt']
    tgt_val_ndx = ds_dict['val_ndx_tgt']
    src_val_ndx = ds_dict['val_ndx_src'] if ds_src is not ds_tgt else None
    test_ndx = ds_dict['test_ndx_tgt']
    sup_ndx = ds_dict['sup_ndx']
    unsup_ndx = ds_dict['unsup_ndx']

    n_classes = ds_src.num_classes
    root_n_classes = math.sqrt(n_classes)

    if bin_fill_holes and n_classes != 2:
        print(
            'Binary hole filling can only be used with binary (2-class) segmentation datasets'
        )
        return

    print('Loaded data')

    # Build network
    NetClass = network_architectures.seg.get(arch)

    student_net = NetClass(ds_src.num_classes).to(torch_device)

    # Pre-trained parameters are trained with a 10x smaller learning rate
    if opt_type == 'adam':
        student_optim = torch.optim.Adam([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ])
    elif opt_type == 'sgd':
        student_optim = torch.optim.SGD([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ], momentum=sgd_momentum, nesterov=sgd_nesterov,
            weight_decay=sgd_weight_decay)
    else:
        raise ValueError('Unknown opt_type {}'.format(opt_type))

    if model == 'mean_teacher':
        teacher_net = NetClass(ds_src.num_classes).to(torch_device)

        # The teacher is never updated by back-propagation
        for p in teacher_net.parameters():
            p.requires_grad = False

        teacher_optim = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, teacher_alpha)
        eval_net = teacher_net
    elif model == 'pi':
        teacher_net = student_net
        teacher_optim = None
        eval_net = student_net
    else:
        print('Unknown model type {}'.format(model))
        return

    BLOCK_SIZE = student_net.BLOCK_SIZE
    NET_MEAN, NET_STD = seg_transforms.get_mean_std(ds_tgt, student_net)

    if freeze_bn:
        if not hasattr(student_net, 'freeze_batchnorm'):
            raise ValueError(
                'Network {} does not support batchnorm freezing'.format(arch))

    clf_crossent_loss = nn.CrossEntropyLoss(ignore_index=255)

    print('Built network')

    mask_generator = mask_gen.BoxMaskGenerator(
        prop_range=mask_prop_range,
        n_boxes=boxmask_n_boxes,
        random_aspect_ratio=not boxmask_fixed_aspect_ratio,
        prop_by_area=not boxmask_by_size,
        within_bounds=not boxmask_outside_bounds,
        invert=not boxmask_no_invert)

    if iters_per_epoch == -1:
        iters_per_epoch = len(unsup_ndx) // batch_size
    total_iters = iters_per_epoch * num_epochs

    lr_epoch_scheduler, lr_iter_scheduler = lr_schedules.make_lr_schedulers(
        optimizer=student_optim,
        total_iters=total_iters,
        schedule_type=lr_sched,
        step_epochs=lr_step_epochs,
        step_gamma=lr_step_gamma,
        poly_power=lr_poly_power)

    # Train data pipeline: transforms
    train_transforms = []
    eval_transforms = []

    if crop_size is not None:
        if aug_scale_hung:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropScaleHung(
                    crop_size, (0, 0),
                    uniform_scale=not aug_scale_non_uniform))
        elif aug_max_scale != 1.0 or aug_rot_mag != 0.0:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropRotateScale(
                    crop_size, (0, 0),
                    rot_mag=aug_rot_mag,
                    max_scale=aug_max_scale,
                    uniform_scale=not aug_scale_non_uniform,
                    constrain_rot_scale=True))
        else:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCrop(crop_size, (0, 0)))
    else:
        if aug_scale_hung:
            raise NotImplementedError('aug_scale_hung requires a crop_size')

    if aug_hflip or aug_vflip or aug_hvflip:
        train_transforms.append(
            seg_transforms_cv.SegCVTransformRandomFlip(aug_hflip, aug_vflip,
                                                       aug_hvflip))
    train_transforms.append(
        seg_transforms_cv.SegCVTransformNormalizeToTensor(NET_MEAN, NET_STD))
    eval_transforms.append(
        seg_transforms_cv.SegCVTransformNormalizeToTensor(NET_MEAN, NET_STD))

    train_sup_ds = ds_src.dataset(
        labels=True, mask=False, xf=False, pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')
    train_unsup_ds = ds_src.dataset(
        labels=False, mask=True, xf=False, pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')
    eval_ds = ds_src.dataset(
        labels=True, mask=False, xf=False, pair=False,
        transforms=seg_transforms.SegTransformCompose(eval_transforms),
        pipeline_type='cv')

    add_mask_params_to_batch = mask_gen.AddMaskParamsToBatch(mask_generator)

    collate_fn = seg_data.SegCollate(BLOCK_SIZE)
    mask_collate_fn = seg_data.SegCollate(
        BLOCK_SIZE, batch_aug_fn=add_mask_params_to_batch)

    # Train data pipeline: data loaders
    sup_sampler = seg_data.RepeatSampler(
        torch.utils.data.SubsetRandomSampler(sup_ndx))
    train_sup_loader = torch.utils.data.DataLoader(train_sup_ds,
                                                   batch_size,
                                                   sampler=sup_sampler,
                                                   collate_fn=collate_fn,
                                                   num_workers=num_workers)
    if cons_weight > 0.0:
        unsup_sampler = seg_data.RepeatSampler(
            torch.utils.data.SubsetRandomSampler(unsup_ndx))
        # Loader 0 attaches mask parameters to every batch
        train_unsup_loader_0 = torch.utils.data.DataLoader(
            train_unsup_ds, batch_size,
            sampler=unsup_sampler,
            collate_fn=mask_collate_fn,
            num_workers=num_workers)
        if mask_mix:
            # Mix mode needs a second stream of unsupervised images
            train_unsup_loader_1 = torch.utils.data.DataLoader(
                train_unsup_ds, batch_size,
                sampler=unsup_sampler,
                collate_fn=collate_fn,
                num_workers=num_workers)
        else:
            train_unsup_loader_1 = None
    else:
        train_unsup_loader_0 = None
        train_unsup_loader_1 = None

    # Eval pipeline
    src_val_loader, tgt_val_loader, test_loader = datasets.eval_data_pipeline(
        ds_src, ds_tgt, src_val_ndx, tgt_val_ndx, test_ndx, batch_size,
        collate_fn, NET_MEAN, NET_STD, num_workers)

    # Report settings
    print('Settings:')
    print(', '.join([
        '{}={}'.format(key, settings[key])
        for key in sorted(list(settings.keys()))
    ]))

    # Report dataset size
    print('Dataset:')
    print('len(sup_ndx)={}'.format(len(sup_ndx)))
    print('len(unsup_ndx)={}'.format(len(unsup_ndx)))
    if ds_src is not ds_tgt:
        # Report the size of the source split under the source label
        print('len(src_val_ndx)={}'.format(len(src_val_ndx)))
        print('len(tgt_val_ndx)={}'.format(len(tgt_val_ndx)))
    else:
        print('len(val_ndx)={}'.format(len(tgt_val_ndx)))
    if test_ndx is not None:
        print('len(test_ndx)={}'.format(len(test_ndx)))
    if n_sup != -1:
        print('sup_ndx={}'.format(sup_ndx.tolist()))

    # Create iterators; the training loaders use repeating samplers, so the
    # same iterators are consumed across all epochs
    train_sup_iter = iter(train_sup_loader)
    train_unsup_iter_0 = iter(
        train_unsup_loader_0) if train_unsup_loader_0 is not None else None
    train_unsup_iter_1 = iter(
        train_unsup_loader_1) if train_unsup_loader_1 is not None else None

    iter_i = 0
    print('Training...')
    for epoch_i in range(num_epochs):
        if lr_epoch_scheduler is not None:
            lr_epoch_scheduler.step(epoch_i)

        t1 = time.time()

        if rampup > 0:
            ramp_val = network_architectures.sigmoid_rampup(epoch_i, rampup)
        else:
            ramp_val = 1.0

        student_net.train()
        if teacher_net is not student_net:
            teacher_net.train()

        if freeze_bn:
            student_net.freeze_batchnorm()
            if teacher_net is not student_net:
                teacher_net.freeze_batchnorm()

        sup_loss_acc = 0.0
        consistency_loss_acc = 0.0
        conf_rate_acc = 0.0
        n_sup_batches = 0
        n_unsup_batches = 0

        src_val_iter = iter(
            src_val_loader) if src_val_loader is not None else None
        tgt_val_iter = iter(
            tgt_val_loader) if tgt_val_loader is not None else None

        for sup_batch in itertools.islice(train_sup_iter, iters_per_epoch):
            if lr_iter_scheduler is not None:
                lr_iter_scheduler.step(iter_i)
            student_optim.zero_grad()

            #
            # Supervised branch
            #

            batch_x = sup_batch['image'].to(torch_device)
            batch_y = sup_batch['labels'].to(torch_device)

            logits_sup = student_net(batch_x)
            sup_loss = clf_crossent_loss(logits_sup, batch_y[:, 0, :, :])
            sup_loss.backward()

            if cons_weight > 0.0:
                for _ in range(unsup_batch_ratio):
                    #
                    # Unsupervised branch
                    #

                    if mask_mix:
                        # Mix mode: batch consists of paired unsupervised samples and mask parameters
                        unsup_batch0 = next(train_unsup_iter_0)
                        unsup_batch1 = next(train_unsup_iter_1)
                        batch_ux0 = unsup_batch0['image'].to(torch_device)
                        batch_um0 = unsup_batch0['mask'].to(torch_device)
                        batch_ux1 = unsup_batch1['image'].to(torch_device)
                        batch_um1 = unsup_batch1['mask'].to(torch_device)
                        batch_mask_params = unsup_batch0['mask_params'].to(
                            torch_device)

                        # batch_um0 and batch_um1 are masks that are 1 for valid pixels, 0 for invalid pixels.
                        # They are used later on to scale the consistency loss, so that consistency loss is
                        # only computed for valid pixels.
                        # Explanation:
                        # When using geometric augmentations such as rotations, some pixels in the training
                        # crop may come from outside the bounds of the input image. These pixels will have a value
                        # of 0 in these masks. Similarly, when using scaled crops, the size of the crop
                        # from the input image that must be scaled to the size of the training crop may be
                        # larger than one/both of the input image dimensions. Pixels in the training crop
                        # that arise from outside the input image bounds will once again be given a value
                        # of 0 in these masks.

                        # Convert mask parameters to masks of shape (N,1,H,W)
                        batch_mix_masks = mask_generator.torch_masks_from_params(
                            batch_mask_params, batch_ux0.shape[2:4],
                            torch_device)

                        # Mix images with masks
                        batch_ux_mixed = batch_ux0 * (
                            1 - batch_mix_masks) + batch_ux1 * batch_mix_masks
                        batch_um_mixed = batch_um0 * (
                            1 - batch_mix_masks) + batch_um1 * batch_mix_masks

                        # Get teacher predictions for original images
                        with torch.no_grad():
                            logits_u0_tea = teacher_net(batch_ux0).detach()
                            logits_u1_tea = teacher_net(batch_ux1).detach()
                        # Get student prediction for mixed image
                        logits_cons_stu = student_net(batch_ux_mixed)

                        # Mix teacher predictions using same mask
                        # It makes no difference whether we do this with logits or probabilities as
                        # the mask pixels are either 1 or 0
                        logits_cons_tea = logits_u0_tea * (
                            1 - batch_mix_masks
                        ) + logits_u1_tea * batch_mix_masks

                        # Logits -> probs
                        prob_cons_tea = F.softmax(logits_cons_tea, dim=1)
                        prob_cons_stu = F.softmax(logits_cons_stu, dim=1)

                        loss_mask = batch_um_mixed

                    else:
                        # Cut mode: batch consists of unsupervised samples and mask params
                        unsup_batch = next(train_unsup_iter_0)
                        batch_ux = unsup_batch['image'].to(torch_device)
                        batch_um = unsup_batch['mask'].to(torch_device)
                        batch_mask_params = unsup_batch['mask_params'].to(
                            torch_device)

                        # Convert mask parameters to masks of shape (N,1,H,W)
                        batch_cut_masks = mask_generator.torch_masks_from_params(
                            batch_mask_params, batch_ux.shape[2:4],
                            torch_device)

                        # Cut image with mask (mask regions to zero)
                        batch_ux_cut = batch_ux * batch_cut_masks

                        # Get teacher predictions for original image
                        with torch.no_grad():
                            logits_cons_tea = teacher_net(batch_ux).detach()
                        # Get student prediction for cut image
                        logits_cons_stu = student_net(batch_ux_cut)

                        # Logits -> probs
                        prob_cons_tea = F.softmax(logits_cons_tea, dim=1)
                        prob_cons_stu = F.softmax(logits_cons_stu, dim=1)

                        loss_mask = batch_cut_masks * batch_um

                    # -- shared by mix and cut --

                    # Confidence thresholding
                    if conf_thresh > 0.0:
                        # Compute confidence of teacher predictions
                        conf_tea = prob_cons_tea.max(dim=1)[0]
                        # Compute confidence mask
                        conf_mask = (conf_tea >=
                                     conf_thresh).float()[:, None, :, :]
                        # Record rate for reporting
                        conf_rate_acc += float(conf_mask.mean())
                        # Average confidence mask if requested
                        if not conf_per_pixel:
                            conf_mask = conf_mask.mean()

                        loss_mask = loss_mask * conf_mask
                    elif rampup > 0:
                        conf_rate_acc += ramp_val

                    # Compute per-pixel consistency loss
                    # Note that the way we aggregate the loss across the class/channel dimension (1)
                    # depends on the loss function used. Generally, summing over the class dimension
                    # keeps the magnitude of the gradient of the loss w.r.t. the logits
                    # nearly constant w.r.t. the number of classes. When using logit-variance,
                    # dividing by `sqrt(num_classes)` helps.
                    if cons_loss_fn == 'var':
                        delta_prob = prob_cons_stu - prob_cons_tea
                        consistency_loss = delta_prob * delta_prob
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'logits_var':
                        delta_logits = logits_cons_stu - logits_cons_tea
                        consistency_loss = delta_logits * delta_logits
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'logits_smoothl1':
                        # reduction='none' keeps the per-element loss
                        # (replaces the deprecated reduce=False)
                        consistency_loss = F.smooth_l1_loss(logits_cons_stu,
                                                            logits_cons_tea,
                                                            reduction='none')
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'bce':
                        consistency_loss = network_architectures.robust_binary_crossentropy(
                            prob_cons_stu, prob_cons_tea)
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'kld':
                        # reduction='none' keeps the per-element loss
                        # (replaces the deprecated reduce=False)
                        consistency_loss = F.kl_div(
                            F.log_softmax(logits_cons_stu, dim=1),
                            prob_cons_tea,
                            reduction='none')
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    else:
                        raise ValueError(
                            'Unknown consistency loss function {}'.format(
                                cons_loss_fn))

                    # Apply consistency loss mask and take the mean over pixels and images
                    consistency_loss = (consistency_loss * loss_mask).mean()

                    # Modulate with rampup if desired
                    if rampup > 0:
                        consistency_loss = consistency_loss * ramp_val

                    # Weight the consistency loss and back-prop
                    unsup_loss = consistency_loss * cons_weight
                    unsup_loss.backward()

                    consistency_loss_acc += float(consistency_loss.detach())

                    n_unsup_batches += 1

            # Gradients of the supervised and unsupervised losses have been
            # accumulated; apply a single optimizer step for both
            student_optim.step()
            if teacher_optim is not None:
                teacher_optim.step()

            sup_loss_val = float(sup_loss.detach())
            if np.isnan(sup_loss_val):
                print('NaN detected; network dead, bailing.')
                return

            sup_loss_acc += sup_loss_val
            n_sup_batches += 1
            iter_i += 1

        sup_loss_acc /= n_sup_batches
        if n_unsup_batches > 0:
            consistency_loss_acc /= n_unsup_batches
            conf_rate_acc /= n_unsup_batches

        #
        # Per-epoch validation
        #
        eval_net.eval()

        if ds_src is not ds_tgt:
            src_iou_eval = evaluation.EvaluatorIoU(ds_src.num_classes,
                                                   bin_fill_holes)
            with torch.no_grad():
                for batch in src_val_iter:
                    batch_x = batch['image'].to(torch_device)
                    batch_y = batch['labels'].numpy()

                    logits = eval_net(batch_x)
                    pred_y = torch.argmax(logits,
                                          dim=1).detach().cpu().numpy()

                    for sample_i in range(len(batch_y)):
                        src_iou_eval.sample(batch_y[sample_i, 0],
                                            pred_y[sample_i],
                                            ignore_value=255)

            src_iou = src_iou_eval.score()
            src_miou = src_iou.mean()
        else:
            src_iou_eval = src_iou = src_miou = None

        tgt_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                               bin_fill_holes)
        with torch.no_grad():
            for batch in tgt_val_iter:
                batch_x = batch['image'].to(torch_device)
                batch_y = batch['labels'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i in range(len(batch_y)):
                    tgt_iou_eval.sample(batch_y[sample_i, 0],
                                        pred_y[sample_i],
                                        ignore_value=255)

        tgt_iou = tgt_iou_eval.score()
        tgt_miou = tgt_iou.mean()

        t2 = time.time()

        if ds_src is not ds_tgt:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, '
                'SRC VAL mIoU={:.3%}, TGT VAL mIoU={:.3%}'.format(
                    epoch_i + 1, t2 - t1, sup_loss_acc, consistency_loss_acc,
                    conf_rate_acc, src_miou, tgt_miou))
            print('-- SRC {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in src_iou])))
            print('-- TGT {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))
        else:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, VAL mIoU={:.3%}'
                .format(epoch_i + 1, t2 - t1, sup_loss_acc,
                        consistency_loss_acc, conf_rate_acc, tgt_miou))
            print('-- {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))

    #
    # Post-training: save model / predictions, evaluate on the test split
    #
    if save_model:
        model_path = os.path.join(submit_config.run_dir, "model.pth")
        # Saves the whole module object, not only its state dict
        torch.save(eval_net, model_path)

    if save_preds:
        out_dir = os.path.join(submit_config.run_dir, 'preds')
        os.makedirs(out_dir, exist_ok=True)
        with torch.no_grad():
            for batch in tgt_val_loader:
                batch_x = batch['image'].to(torch_device)
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    ds_tgt.save_prediction_by_index(
                        out_dir, pred_y[sample_i].astype(np.uint32),
                        sample_ndx)
    else:
        out_dir = None

    if test_loader is not None:
        test_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                                bin_fill_holes)
        with torch.no_grad():
            for batch in test_loader:
                batch_x = batch['image'].to(torch_device)
                # Ground truth must come from the current test batch; reusing
                # a `batch_y` left over from validation would score the test
                # predictions against the wrong labels.
                # NOTE(review): assumes test batches carry 'labels' like the
                # validation batches from the same eval pipeline -- confirm.
                batch_y = batch['labels'].numpy()
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    if save_preds:
                        ds_tgt.save_prediction_by_index(
                            out_dir, pred_y[sample_i].astype(np.uint32),
                            sample_ndx)
                    test_iou_eval.sample(batch_y[sample_i, 0],
                                         pred_y[sample_i],
                                         ignore_value=255)

        test_iou = test_iou_eval.score()
        test_miou = test_iou.mean()

        print('FINAL TEST: mIoU={:.3%}'.format(test_miou))
        print('-- TEST {}'.format(', '.join(
            ['{:.3%}'.format(x) for x in test_iou])))