def experiment(exp, arch, loss, double_softmax, confidence_thresh, rampup,
               teacher_alpha, fix_ema, unsup_weight, cls_bal_scale,
               cls_bal_scale_range, cls_balance, cls_balance_loss,
               combine_batches, learning_rate, standardise_samples,
               src_affine_std, src_xlat_range, src_hflip, src_intens_flip,
               src_intens_scale_range, src_intens_offset_range,
               src_gaussian_noise_std, tgt_affine_std, tgt_xlat_range,
               tgt_hflip, tgt_intens_flip, tgt_intens_scale_range,
               tgt_intens_offset_range, tgt_gaussian_noise_std, num_epochs,
               batch_size, epoch_size, seed, log_file, model_file, device):
    """Run a mean-teacher self-ensembling domain-adaptation experiment.

    Trains a student network on labelled source-domain samples and an
    EMA teacher on augmented unlabelled target-domain samples, logging
    per-epoch losses and test errors, and optionally saving the best
    teacher weights to ``model_file``.

    :param exp: experiment name selecting the source/target dataset pair
        (e.g. ``'svhn_mnist'``, ``'cifar_stl'``)
    :param arch: network architecture name; ``''`` selects a default
        appropriate for ``exp``
    :param loss: ``'bce'`` for robust binary cross-entropy augmentation
        loss, otherwise squared difference is used
    :param rampup: number of ramp-up epochs; > 0 enables ramp-up weighting
        in place of confidence thresholding
    :param epoch_size: ``'large'``, ``'small'`` or ``'target'`` — how the
        number of samples per epoch is derived from the dataset sizes
    :param log_file: ``''`` derives a name from ``exp``; ``'none'``
        disables file logging
    :param model_file: path to save the best teacher state; ``''`` disables
        saving
    :param device: torch device string, e.g. ``'cuda:0'``

    Returns ``None``; exits early (with a message) on configuration errors.
    """
    # Capture the call parameters before any locals are created, for the
    # settings report below.
    settings = locals().copy()

    import os
    import cmdline_helpers

    if log_file == '':
        log_file = 'output_aug_log_{}.txt'.format(exp)
    elif log_file == 'none':
        log_file = None

    if log_file is not None:
        if os.path.exists(log_file):
            print('Output log file {} already exists'.format(log_file))
            return

    use_rampup = rampup > 0

    src_intens_scale_range_lower, src_intens_scale_range_upper, src_intens_offset_range_lower, src_intens_offset_range_upper = \
        cmdline_helpers.intens_aug_options(src_intens_scale_range, src_intens_offset_range)
    tgt_intens_scale_range_lower, tgt_intens_scale_range_upper, tgt_intens_offset_range_lower, tgt_intens_offset_range_upper = \
        cmdline_helpers.intens_aug_options(tgt_intens_scale_range, tgt_intens_offset_range)

    import time
    import math
    import numpy as np
    from batchup import data_source, work_pool
    import data_loaders
    import standardisation
    import network_architectures
    import augmentation
    import torch, torch.cuda
    from torch import nn
    from torch.nn import functional as F
    import optim_weight_ema

    torch_device = torch.device(device)
    # Worker pool used to run data augmentation in background threads.
    pool = work_pool.WorkerThreadPool(2)

    # Select the source/target dataset pair for the requested experiment.
    if exp == 'svhn_mnist':
        d_source = data_loaders.load_svhn(zero_centre=False, greyscale=True)
        d_target = data_loaders.load_mnist(invert=False, zero_centre=False,
                                           pad32=True, val=False)
    elif exp == 'mnist_svhn':
        d_source = data_loaders.load_mnist(invert=False, zero_centre=False,
                                           pad32=True)
        d_target = data_loaders.load_svhn(zero_centre=False, greyscale=True,
                                          val=False)
    elif exp == 'svhn_mnist_rgb':
        d_source = data_loaders.load_svhn(zero_centre=False, greyscale=False)
        d_target = data_loaders.load_mnist(invert=False, zero_centre=False,
                                           pad32=True, val=False, rgb=True)
    elif exp == 'mnist_svhn_rgb':
        d_source = data_loaders.load_mnist(invert=False, zero_centre=False,
                                           pad32=True, rgb=True)
        d_target = data_loaders.load_svhn(zero_centre=False, greyscale=False,
                                          val=False)
    elif exp == 'cifar_stl':
        d_source = data_loaders.load_cifar10(range_01=False)
        d_target = data_loaders.load_stl(zero_centre=False, val=False)
    elif exp == 'stl_cifar':
        d_source = data_loaders.load_stl(zero_centre=False)
        d_target = data_loaders.load_cifar10(range_01=False, val=False)
    elif exp == 'mnist_usps':
        d_source = data_loaders.load_mnist(zero_centre=False)
        d_target = data_loaders.load_usps(zero_centre=False, scale28=True,
                                          val=False)
    elif exp == 'usps_mnist':
        d_source = data_loaders.load_usps(zero_centre=False, scale28=True)
        d_target = data_loaders.load_mnist(zero_centre=False, val=False)
    elif exp == 'syndigits_svhn':
        d_source = data_loaders.load_syn_digits(zero_centre=False)
        d_target = data_loaders.load_svhn(zero_centre=False, val=False)
    elif exp == 'synsigns_gtsrb':
        d_source = data_loaders.load_syn_signs(zero_centre=False)
        d_target = data_loaders.load_gtsrb(zero_centre=False, val=False)
    else:
        print('Unknown experiment type \'{}\''.format(exp))
        return

    # Delete the training ground truths as we should not be using them
    del d_target.train_y

    if standardise_samples:
        standardisation.standardise_dataset(d_source)
        standardisation.standardise_dataset(d_target)

    n_classes = d_source.n_classes

    print('Loaded data')

    # Pick a default architecture appropriate for the experiment if the
    # caller did not specify one.
    if arch == '':
        if exp in {'mnist_usps', 'usps_mnist'}:
            arch = 'mnist-bn-32-64-256'
        if exp in {'svhn_mnist', 'mnist_svhn'}:
            arch = 'grey-32-64-128-gp'
        if exp in {'cifar_stl', 'stl_cifar', 'syndigits_svhn',
                   'svhn_mnist_rgb', 'mnist_svhn_rgb'}:
            arch = 'rgb-128-256-down-gp'
        if exp in {'synsigns_gtsrb'}:
            arch = 'rgb40-96-192-384-gp'

    net_class, expected_shape = network_architectures.get_net_and_shape_for_architecture(arch)

    if expected_shape != d_source.train_X.shape[1:]:
        print('Architecture {} not compatible with experiment {}; it needs samples of shape {}, '
              'data has samples of shape {}'.format(
                  arch, exp, expected_shape, d_source.train_X.shape[1:]))
        return

    student_net = net_class(n_classes).to(torch_device)
    teacher_net = net_class(n_classes).to(torch_device)
    student_params = list(student_net.parameters())
    teacher_params = list(teacher_net.parameters())
    # The teacher is updated via EMA of the student weights, never by
    # gradient descent.
    for param in teacher_params:
        param.requires_grad = False

    student_optimizer = torch.optim.Adam(student_params, lr=learning_rate)
    if fix_ema:
        teacher_optimizer = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, alpha=teacher_alpha)
    else:
        teacher_optimizer = optim_weight_ema.OldWeightEMA(
            teacher_net, student_net, alpha=teacher_alpha)
    classification_criterion = nn.CrossEntropyLoss()
    print('Built network')

    src_aug = augmentation.ImageAugmentation(
        src_hflip, src_xlat_range, src_affine_std,
        intens_flip=src_intens_flip,
        intens_scale_range_lower=src_intens_scale_range_lower,
        intens_scale_range_upper=src_intens_scale_range_upper,
        intens_offset_range_lower=src_intens_offset_range_lower,
        intens_offset_range_upper=src_intens_offset_range_upper,
        gaussian_noise_std=src_gaussian_noise_std)

    tgt_aug = augmentation.ImageAugmentation(
        tgt_hflip, tgt_xlat_range, tgt_affine_std,
        intens_flip=tgt_intens_flip,
        intens_scale_range_lower=tgt_intens_scale_range_lower,
        intens_scale_range_upper=tgt_intens_scale_range_upper,
        intens_offset_range_lower=tgt_intens_offset_range_lower,
        intens_offset_range_upper=tgt_intens_offset_range_upper,
        gaussian_noise_std=tgt_gaussian_noise_std)

    if combine_batches:
        # Combined mode: source samples are also augmented in pairs so that
        # both domains pass through student and teacher.
        def augment(X_sup, y_src, X_tgt):
            X_src_stu, X_src_tea = src_aug.augment_pair(X_sup)
            X_tgt_stu, X_tgt_tea = tgt_aug.augment_pair(X_tgt)
            return X_src_stu, X_src_tea, y_src, X_tgt_stu, X_tgt_tea
    else:
        # Separate mode: single augmentation for supervised source samples,
        # paired augmentation (student/teacher views) for target samples.
        def augment(X_src, y_src, X_tgt):
            X_src = src_aug.augment(X_src)
            X_tgt_stu, X_tgt_tea = tgt_aug.augment_pair(X_tgt)
            return X_src, y_src, X_tgt_stu, X_tgt_tea

    # One-element list acts as a mutable cell so the training closure sees
    # the per-epoch ramp-up weight.
    rampup_weight_in_list = [0]

    cls_bal_fn = network_architectures.get_cls_bal_function(cls_balance_loss)

    def compute_aug_loss(stu_out, tea_out):
        """Compute the unsupervised consistency (augmentation) loss between
        student and teacher probability predictions, plus the confidence
        mask counts used for reporting (None when ramp-up is in use)."""
        # Augmentation loss
        if use_rampup:
            # Ramp-up weighting replaces confidence thresholding.
            unsup_mask = None
            conf_mask_count = None
            unsup_mask_count = None
        else:
            # Mask out samples whose teacher prediction is not confident.
            conf_tea = torch.max(tea_out, 1)[0]
            unsup_mask = conf_mask = (conf_tea > confidence_thresh).float()
            unsup_mask_count = conf_mask_count = conf_mask.sum()

        if loss == 'bce':
            aug_loss = network_architectures.robust_binary_crossentropy(
                stu_out, tea_out)
        else:
            d_aug_loss = stu_out - tea_out
            aug_loss = d_aug_loss * d_aug_loss

        # Class balance scaling
        if cls_bal_scale:
            if use_rampup:
                n_samples = float(aug_loss.shape[0])
            else:
                n_samples = unsup_mask.sum()
            avg_pred = n_samples / float(n_classes)
            bal_scale = avg_pred / torch.clamp(tea_out.sum(dim=0), min=1.0)
            if cls_bal_scale_range != 0.0:
                bal_scale = torch.clamp(bal_scale,
                                        min=1.0 / cls_bal_scale_range,
                                        max=cls_bal_scale_range)
            bal_scale = bal_scale.detach()
            aug_loss = aug_loss * bal_scale[None, :]

        aug_loss = aug_loss.mean(dim=1)

        if use_rampup:
            unsup_loss = aug_loss.mean() * rampup_weight_in_list[0]
        else:
            unsup_loss = (aug_loss * unsup_mask).mean()

        # Class balance loss
        if cls_balance > 0.0:
            # Average over samples to get the average class prediction
            avg_cls_prob = stu_out.mean(dim=0)
            # Penalize deviation from a uniform class distribution
            equalise_cls_loss = cls_bal_fn(avg_cls_prob,
                                           float(1.0 / n_classes))
            equalise_cls_loss = equalise_cls_loss.mean() * n_classes

            if use_rampup:
                equalise_cls_loss = equalise_cls_loss * rampup_weight_in_list[0]
            else:
                if rampup == 0:
                    equalise_cls_loss = equalise_cls_loss * unsup_mask.mean(dim=0)

            unsup_loss += equalise_cls_loss * cls_balance

        return unsup_loss, conf_mask_count, unsup_mask_count

    if combine_batches:
        def f_train(X_src0, X_src1, y_src, X_tgt0, X_tgt1):
            """One training step on a combined source+target mini-batch;
            returns summed losses (and mask counts unless ramping up)."""
            X_src0 = torch.tensor(X_src0, dtype=torch.float, device=torch_device)
            X_src1 = torch.tensor(X_src1, dtype=torch.float, device=torch_device)
            y_src = torch.tensor(y_src, dtype=torch.long, device=torch_device)
            X_tgt0 = torch.tensor(X_tgt0, dtype=torch.float, device=torch_device)
            X_tgt1 = torch.tensor(X_tgt1, dtype=torch.float, device=torch_device)

            n_samples = X_src0.size()[0]
            n_total = n_samples + X_tgt0.size()[0]

            student_optimizer.zero_grad()
            student_net.train()
            teacher_net.train()

            # Concatenate source and target mini-batches
            X0 = torch.cat([X_src0, X_tgt0], 0)
            X1 = torch.cat([X_src1, X_tgt1], 0)

            student_logits_out = student_net(X0)
            student_prob_out = F.softmax(student_logits_out, dim=1)

            src_logits_out = student_logits_out[:n_samples]
            src_prob_out = student_prob_out[:n_samples]

            teacher_logits_out = teacher_net(X1)
            teacher_prob_out = F.softmax(teacher_logits_out, dim=1)

            # Supervised classification loss
            if double_softmax:
                clf_loss = classification_criterion(src_prob_out, y_src)
            else:
                clf_loss = classification_criterion(src_logits_out, y_src)

            unsup_loss, conf_mask_count, unsup_mask_count = compute_aug_loss(
                student_prob_out, teacher_prob_out)

            loss_expr = clf_loss + unsup_loss * unsup_weight

            loss_expr.backward()
            student_optimizer.step()
            teacher_optimizer.step()

            outputs = [float(clf_loss) * n_samples,
                       float(unsup_loss) * n_total]
            if not use_rampup:
                # Halve the counts: the mask covers both source and target
                # halves of the combined batch.
                mask_count = float(conf_mask_count) * 0.5
                unsup_count = float(unsup_mask_count) * 0.5
                outputs.append(mask_count)
                outputs.append(unsup_count)
            return tuple(outputs)
    else:
        def f_train(X_src, y_src, X_tgt0, X_tgt1):
            """One training step with separate source (supervised) and
            target (consistency) batches; returns summed losses (and mask
            counts unless ramping up)."""
            X_src = torch.tensor(X_src, dtype=torch.float, device=torch_device)
            y_src = torch.tensor(y_src, dtype=torch.long, device=torch_device)
            X_tgt0 = torch.tensor(X_tgt0, dtype=torch.float, device=torch_device)
            X_tgt1 = torch.tensor(X_tgt1, dtype=torch.float, device=torch_device)

            student_optimizer.zero_grad()
            student_net.train()
            teacher_net.train()

            src_logits_out = student_net(X_src)
            student_tgt_logits_out = student_net(X_tgt0)
            student_tgt_prob_out = F.softmax(student_tgt_logits_out, dim=1)
            teacher_tgt_logits_out = teacher_net(X_tgt1)
            teacher_tgt_prob_out = F.softmax(teacher_tgt_logits_out, dim=1)

            # Supervised classification loss
            if double_softmax:
                clf_loss = classification_criterion(
                    F.softmax(src_logits_out, dim=1), y_src)
            else:
                clf_loss = classification_criterion(src_logits_out, y_src)

            unsup_loss, conf_mask_count, unsup_mask_count = compute_aug_loss(
                student_tgt_prob_out, teacher_tgt_prob_out)

            loss_expr = clf_loss + unsup_loss * unsup_weight

            loss_expr.backward()
            student_optimizer.step()
            teacher_optimizer.step()

            n_samples = X_src.size()[0]

            outputs = [float(clf_loss) * n_samples,
                       float(unsup_loss) * n_samples]
            if not use_rampup:
                mask_count = float(conf_mask_count)
                unsup_count = float(unsup_mask_count)
                outputs.append(mask_count)
                outputs.append(unsup_count)
            return tuple(outputs)

    print('Compiled training function')

    def f_pred_src(X_sup):
        """Predict class probabilities for source samples with both nets."""
        X_var = torch.tensor(X_sup, dtype=torch.float, device=torch_device)
        student_net.eval()
        teacher_net.eval()
        return (F.softmax(student_net(X_var), dim=1).detach().cpu().numpy(),
                F.softmax(teacher_net(X_var), dim=1).detach().cpu().numpy())

    def f_pred_tgt(X_sup):
        """Predict class probabilities for target samples with both nets."""
        X_var = torch.tensor(X_sup, dtype=torch.float, device=torch_device)
        student_net.eval()
        teacher_net.eval()
        return (F.softmax(student_net(X_var), dim=1).detach().cpu().numpy(),
                F.softmax(teacher_net(X_var), dim=1).detach().cpu().numpy())

    def f_eval_src(X_sup, y_sup):
        """Count student/teacher misclassifications on a source batch."""
        y_pred_prob_stu, y_pred_prob_tea = f_pred_src(X_sup)
        y_pred_stu = np.argmax(y_pred_prob_stu, axis=1)
        y_pred_tea = np.argmax(y_pred_prob_tea, axis=1)
        return (float((y_pred_stu != y_sup).sum()),
                float((y_pred_tea != y_sup).sum()))

    def f_eval_tgt(X_sup, y_sup):
        """Count student/teacher misclassifications on a target batch."""
        y_pred_prob_stu, y_pred_prob_tea = f_pred_tgt(X_sup)
        y_pred_stu = np.argmax(y_pred_prob_stu, axis=1)
        y_pred_tea = np.argmax(y_pred_prob_tea, axis=1)
        return (float((y_pred_stu != y_sup).sum()),
                float((y_pred_tea != y_sup).sum()))

    print('Compiled evaluation function')

    # Setup output
    def log(text):
        """Print `text` and append it to the log file, if one is in use."""
        print(text)
        if log_file is not None:
            # `with` closes (and flushes) the file; no explicit
            # flush()/close() needed.
            with open(log_file, 'a') as f:
                f.write(text + '\n')

    # Only create the containing directory when file logging is enabled.
    if log_file is not None:
        cmdline_helpers.ensure_containing_dir_exists(log_file)

    # Report settings
    log('Settings: {}'.format(', '.join(
        ['{}={}'.format(key, settings[key])
         for key in sorted(list(settings.keys()))])))

    # Report dataset size
    log('Dataset:')
    log('SOURCE Train: X.shape={}, y.shape={}'.format(
        d_source.train_X.shape, d_source.train_y.shape))
    log('SOURCE Test: X.shape={}, y.shape={}'.format(
        d_source.test_X.shape, d_source.test_y.shape))
    log('TARGET Train: X.shape={}'.format(d_target.train_X.shape))
    log('TARGET Test: X.shape={}, y.shape={}'.format(
        d_target.test_X.shape, d_target.test_y.shape))

    print('Training...')
    sup_ds = data_source.ArrayDataSource(
        [d_source.train_X, d_source.train_y], repeats=-1)
    tgt_train_ds = data_source.ArrayDataSource(
        [d_target.train_X], repeats=-1)
    train_ds = data_source.CompositeDataSource(
        [sup_ds, tgt_train_ds]).map(augment)
    train_ds = pool.parallel_data_source(train_ds)

    if epoch_size == 'large':
        n_samples = max(d_source.train_X.shape[0], d_target.train_X.shape[0])
    elif epoch_size == 'small':
        n_samples = min(d_source.train_X.shape[0], d_target.train_X.shape[0])
    elif epoch_size == 'target':
        n_samples = d_target.train_X.shape[0]
    else:
        # Fail early with a clear message instead of a NameError below.
        print('Unknown epoch size \'{}\''.format(epoch_size))
        return
    n_train_batches = n_samples // batch_size

    source_test_ds = data_source.ArrayDataSource(
        [d_source.test_X, d_source.test_y])
    target_test_ds = data_source.ArrayDataSource(
        [d_target.test_X, d_target.test_y])

    if seed != 0:
        shuffle_rng = np.random.RandomState(seed)
    else:
        shuffle_rng = np.random

    train_batch_iter = train_ds.batch_iterator(batch_size=batch_size,
                                               shuffle=shuffle_rng)

    best_teacher_model_state = {k: v.cpu().numpy()
                                for k, v in teacher_net.state_dict().items()}

    best_conf_mask_rate = 0.0
    best_src_test_err = 1.0
    for epoch in range(num_epochs):
        t1 = time.time()

        if use_rampup:
            # Sigmoid-shaped ramp-up of the unsupervised loss weight over
            # the first `rampup` epochs.
            if epoch < rampup:
                p = max(0.0, float(epoch)) / float(rampup)
                p = 1.0 - p
                rampup_value = math.exp(-p * p * 5.0)
            else:
                rampup_value = 1.0

            rampup_weight_in_list[0] = rampup_value

        train_res = data_source.batch_map_mean(f_train, train_batch_iter,
                                               n_batches=n_train_batches)

        train_clf_loss = train_res[0]
        if combine_batches:
            unsup_loss_string = 'unsup (both) loss={:.6f}'.format(train_res[1])
        else:
            unsup_loss_string = 'unsup (tgt) loss={:.6f}'.format(train_res[1])

        src_test_err_stu, src_test_err_tea = source_test_ds.batch_map_mean(
            f_eval_src, batch_size=batch_size * 2)
        tgt_test_err_stu, tgt_test_err_tea = target_test_ds.batch_map_mean(
            f_eval_tgt, batch_size=batch_size * 2)

        if use_rampup:
            unsup_loss_string = '{}, rampup={:.3%}'.format(
                unsup_loss_string, rampup_value)
            # Model selection by source test error when ramping up.
            if src_test_err_stu < best_src_test_err:
                best_src_test_err = src_test_err_stu
                best_teacher_model_state = {
                    k: v.cpu().numpy()
                    for k, v in teacher_net.state_dict().items()}
                improve = '*** '
            else:
                improve = ''
        else:
            # Model selection by confidence mask rate otherwise.
            conf_mask_rate = train_res[-2]
            unsup_mask_rate = train_res[-1]
            if conf_mask_rate > best_conf_mask_rate:
                best_conf_mask_rate = conf_mask_rate
                improve = '*** '
                best_teacher_model_state = {
                    k: v.cpu().numpy()
                    for k, v in teacher_net.state_dict().items()}
            else:
                improve = ''
            unsup_loss_string = '{}, conf mask={:.3%}, unsup mask={:.3%}'.format(
                unsup_loss_string, conf_mask_rate, unsup_mask_rate)

        t2 = time.time()

        log('{}Epoch {} took {:.2f}s: TRAIN clf loss={:.6f}, {}; '
            'SRC TEST ERR={:.3%}, TGT TEST student err={:.3%}, TGT TEST teacher err={:.3%}'
            .format(improve, epoch, t2 - t1, train_clf_loss,
                    unsup_loss_string, src_test_err_stu, tgt_test_err_stu,
                    tgt_test_err_tea))

    # Save network
    if model_file != '':
        cmdline_helpers.ensure_containing_dir_exists(model_file)
        with open(model_file, 'wb') as f:
            torch.save(best_teacher_model_state, f)
def train_seg_semisup_aug_mt(
        submit_config: job_helper.SubmitConfig, dataset, model, arch,
        freeze_bn, opt_type, sgd_momentum, sgd_nesterov, sgd_weight_decay,
        learning_rate, lr_sched, lr_step_epochs, lr_step_gamma,
        lr_poly_power, teacher_alpha, bin_fill_holes, crop_size,
        aug_offset_range, aug_hflip, aug_vflip, aug_hvflip, aug_scale_hung,
        aug_max_scale, aug_scale_non_uniform, aug_rot_mag,
        aug_free_scale_rot, cons_loss_fn, cons_weight, conf_thresh,
        conf_per_pixel, rampup, unsup_batch_ratio, num_epochs,
        iters_per_epoch, batch_size, n_sup, n_unsup, n_val, split_seed,
        split_path, val_seed, save_preds, save_model, num_workers):
    """Train a semi-supervised segmentation network using augmentation-driven
    consistency regularization with a mean teacher (or pi-model).

    A student network is trained with cross-entropy on labelled samples and
    with a consistency loss between its predictions on one augmented view
    of an unlabelled sample and the teacher's predictions on a second view,
    after warping teacher predictions into the student's view via the
    known affine transform between the two crops.

    :param model: ``'mean_teacher'`` (EMA teacher) or ``'pi'`` (teacher is
        the student itself)
    :param cons_loss_fn: one of ``'var'``, ``'logits_var'``,
        ``'logits_smoothl1'``, ``'bce'``, ``'kld'``
    :param conf_thresh: teacher-confidence threshold (> 0 enables
        confidence masking of the consistency loss)
    :param rampup: epochs of sigmoid ramp-up of the consistency weight
        (0 disables)
    :param iters_per_epoch: iterations per epoch, or -1 to derive from the
        unsupervised set size

    Returns ``None``; exits early (with a message) on configuration errors.
    """
    # Capture call parameters for the settings report; drop the run config
    # object itself.
    settings = locals().copy()
    del settings['submit_config']

    import os
    import math
    import time
    import itertools
    import numpy as np
    import torch.nn as nn, torch.nn.functional as F
    from architectures import network_architectures
    import torch.utils.data
    from datapipe import datasets
    from datapipe import seg_data, seg_transforms, seg_transforms_cv
    import evaluation
    import optim_weight_ema
    import lr_schedules
    from datapipe import torch_utils

    # Keyword dict for the `align_corners` argument of affine_grid /
    # grid_sample (empty on old PyTorch versions that don't accept it).
    affine_align_corners_kw = torch_utils.affine_align_corners_kw(True)

    if crop_size == '':
        crop_size = None
    else:
        crop_size = [int(x.strip()) for x in crop_size.split(',')]

    torch_device = torch.device('cuda:0')

    #
    # Load data sets
    #
    ds_dict = datasets.load_dataset(dataset, n_val, val_seed, n_sup, n_unsup,
                                    split_seed, split_path)

    ds_src = ds_dict['ds_src']
    ds_tgt = ds_dict['ds_tgt']
    tgt_val_ndx = ds_dict['val_ndx_tgt']
    src_val_ndx = ds_dict['val_ndx_src'] if ds_src is not ds_tgt else None
    test_ndx = ds_dict['test_ndx_tgt']
    sup_ndx = ds_dict['sup_ndx']
    unsup_ndx = ds_dict['unsup_ndx']

    n_classes = ds_src.num_classes
    root_n_classes = math.sqrt(n_classes)

    if bin_fill_holes and n_classes != 2:
        print('Binary hole filling can only be used with binary (2-class) '
              'segmentation datasets')
        return

    print('Loaded data')

    # Build network
    NetClass = network_architectures.seg.get(arch)

    student_net = NetClass(ds_src.num_classes).to(torch_device)

    if opt_type == 'adam':
        # Pre-trained backbone parameters get a 10x smaller learning rate.
        student_optim = torch.optim.Adam([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)])
    elif opt_type == 'sgd':
        student_optim = torch.optim.SGD([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)],
            momentum=sgd_momentum, nesterov=sgd_nesterov,
            weight_decay=sgd_weight_decay)
    else:
        raise ValueError('Unknown opt_type {}'.format(opt_type))

    if model == 'mean_teacher':
        teacher_net = NetClass(ds_src.num_classes).to(torch_device)
        # Teacher weights follow the student via EMA, not gradient descent.
        for p in teacher_net.parameters():
            p.requires_grad = False
        teacher_optim = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, teacher_alpha)
        eval_net = teacher_net
    elif model == 'pi':
        teacher_net = student_net
        teacher_optim = None
        eval_net = student_net
    else:
        print('Unknown model type {}'.format(model))
        return

    BLOCK_SIZE = student_net.BLOCK_SIZE
    NET_MEAN, NET_STD = seg_transforms.get_mean_std(ds_tgt, student_net)

    if freeze_bn:
        if not hasattr(student_net, 'freeze_batchnorm'):
            raise ValueError('Network {} does not support batchnorm '
                             'freezing'.format(arch))

    clf_crossent_loss = nn.CrossEntropyLoss(ignore_index=255)

    print('Built network')

    if iters_per_epoch == -1:
        iters_per_epoch = len(unsup_ndx) // batch_size
    total_iters = iters_per_epoch * num_epochs

    lr_epoch_scheduler, lr_iter_scheduler = lr_schedules.make_lr_schedulers(
        optimizer=student_optim, total_iters=total_iters,
        schedule_type=lr_sched, step_epochs=lr_step_epochs,
        step_gamma=lr_step_gamma, poly_power=lr_poly_power)

    # Train data pipeline: transforms
    train_transforms = []

    if crop_size is not None:
        if aug_scale_hung:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropScaleHung(
                    crop_size, (aug_offset_range, aug_offset_range),
                    uniform_scale=not aug_scale_non_uniform))
        elif aug_max_scale != 1.0 or aug_rot_mag != 0.0:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropRotateScale(
                    crop_size, (aug_offset_range, aug_offset_range),
                    rot_mag=aug_rot_mag, max_scale=aug_max_scale,
                    uniform_scale=not aug_scale_non_uniform,
                    constrain_rot_scale=not aug_free_scale_rot))
        else:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCrop(
                    crop_size, (aug_offset_range, aug_offset_range)))
    else:
        if aug_scale_hung:
            raise NotImplementedError('aug_scale_hung requires a crop_size')

    if aug_hflip or aug_vflip or aug_hvflip:
        train_transforms.append(
            seg_transforms_cv.SegCVTransformRandomFlip(
                aug_hflip, aug_vflip, aug_hvflip))
    train_transforms.append(
        seg_transforms_cv.SegCVTransformNormalizeToTensor(NET_MEAN, NET_STD))

    # Train data pipeline: supervised and unsupervised data sets
    train_sup_ds = ds_src.dataset(
        labels=True, mask=False, xf=False, pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')
    train_unsup_ds = ds_src.dataset(
        labels=False, mask=True, xf=True, pair=True,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')

    collate_fn = seg_data.SegCollate(BLOCK_SIZE)

    # Train data pipeline: data loaders
    sup_sampler = seg_data.RepeatSampler(
        torch.utils.data.SubsetRandomSampler(sup_ndx))
    train_sup_loader = torch.utils.data.DataLoader(
        train_sup_ds, batch_size, sampler=sup_sampler,
        collate_fn=collate_fn, num_workers=num_workers)
    if cons_weight > 0.0:
        unsup_sampler = seg_data.RepeatSampler(
            torch.utils.data.SubsetRandomSampler(unsup_ndx))
        train_unsup_loader = torch.utils.data.DataLoader(
            train_unsup_ds, batch_size, sampler=unsup_sampler,
            collate_fn=collate_fn, num_workers=num_workers)
    else:
        train_unsup_loader = None

    # Eval pipeline
    src_val_loader, tgt_val_loader, test_loader = datasets.eval_data_pipeline(
        ds_src, ds_tgt, src_val_ndx, tgt_val_ndx, test_ndx, batch_size,
        collate_fn, NET_MEAN, NET_STD, num_workers)

    # Report settings
    print('Settings:')
    print(', '.join(['{}={}'.format(key, settings[key])
                     for key in sorted(list(settings.keys()))]))

    # Report dataset size
    print('Dataset:')
    print('len(sup_ndx)={}'.format(len(sup_ndx)))
    print('len(unsup_ndx)={}'.format(len(unsup_ndx)))
    if ds_src is not ds_tgt:
        # Fixed: previously printed len(tgt_val_ndx) for the source split.
        print('len(src_val_ndx)={}'.format(len(src_val_ndx)))
        print('len(tgt_val_ndx)={}'.format(len(tgt_val_ndx)))
    else:
        print('len(val_ndx)={}'.format(len(tgt_val_ndx)))
    if test_ndx is not None:
        print('len(test_ndx)={}'.format(len(test_ndx)))

    if n_sup != -1:
        print('sup_ndx={}'.format(sup_ndx.tolist()))

    # NOTE(review): these early-stopping trackers are assigned but never
    # updated anywhere in this function — early stopping is not implemented.
    best_tgt_miou = None
    best_epoch = 0

    eval_net_state = {key: value.detach().cpu().numpy()
                      for key, value in eval_net.state_dict().items()}

    # Create iterators
    train_sup_iter = iter(train_sup_loader)
    train_unsup_iter = iter(train_unsup_loader) \
        if train_unsup_loader is not None else None

    iter_i = 0
    print('Training...')
    for epoch_i in range(num_epochs):
        if lr_epoch_scheduler is not None:
            lr_epoch_scheduler.step(epoch_i)

        t1 = time.time()

        if rampup > 0:
            ramp_val = network_architectures.sigmoid_rampup(epoch_i, rampup)
        else:
            ramp_val = 1.0

        student_net.train()
        if teacher_net is not student_net:
            teacher_net.train()

        if freeze_bn:
            student_net.freeze_batchnorm()
            if teacher_net is not student_net:
                teacher_net.freeze_batchnorm()

        sup_loss_acc = 0.0
        consistency_loss_acc = 0.0
        conf_rate_acc = 0.0
        n_sup_batches = 0
        n_unsup_batches = 0

        src_val_iter = iter(src_val_loader) \
            if src_val_loader is not None else None
        tgt_val_iter = iter(tgt_val_loader) \
            if tgt_val_loader is not None else None

        for sup_batch in itertools.islice(train_sup_iter, iters_per_epoch):
            if lr_iter_scheduler is not None:
                lr_iter_scheduler.step(iter_i)
            student_optim.zero_grad()

            #
            # Supervised branch
            #
            batch_x = sup_batch['image'].to(torch_device)
            batch_y = sup_batch['labels'].to(torch_device)

            logits_sup = student_net(batch_x)
            sup_loss = clf_crossent_loss(logits_sup, batch_y[:, 0, :, :])
            sup_loss.backward()

            if cons_weight > 0.0:
                for _ in range(unsup_batch_ratio):
                    #
                    # Unsupervised branch: paired augmented crops of the
                    # same unlabelled image, plus the affine transform
                    # relating them.
                    #
                    unsup_batch = next(train_unsup_iter)
                    batch_ux0 = unsup_batch['sample0']['image'].to(torch_device)
                    batch_um0 = unsup_batch['sample0']['mask'].to(torch_device)
                    batch_ux1 = unsup_batch['sample1']['image'].to(torch_device)
                    batch_um1 = unsup_batch['sample1']['mask'].to(torch_device)
                    batch_ufx0_to_1 = unsup_batch['xf0_to_1'].to(torch_device)

                    # Get teacher predictions for image0
                    with torch.no_grad():
                        logits_cons_tea = teacher_net(batch_ux0).detach()
                    # Get student prediction for image1
                    logits_cons_stu = student_net(batch_ux1)

                    # Transformation from teacher to student space
                    grid_tea_to_stu = F.affine_grid(
                        batch_ufx0_to_1, batch_ux0.shape,
                        **affine_align_corners_kw)
                    # Transform teacher predicted logits to student space
                    logits_cons_tea_in_stu = F.grid_sample(
                        logits_cons_tea, grid_tea_to_stu,
                        **affine_align_corners_kw)
                    # Transform mask from teacher to student space and
                    # multiply by student space mask
                    mask_tea_in_stu = F.grid_sample(
                        batch_um0, grid_tea_to_stu,
                        **affine_align_corners_kw) * batch_um1

                    # Logits -> probs
                    prob_cons_tea = F.softmax(logits_cons_tea, dim=1)
                    prob_cons_stu = F.softmax(logits_cons_stu, dim=1)
                    # Transform teacher predicted probabilities to student
                    # space
                    prob_cons_tea_in_stu = F.grid_sample(
                        prob_cons_tea, grid_tea_to_stu,
                        **affine_align_corners_kw)

                    loss_mask = mask_tea_in_stu

                    # Confidence thresholding
                    if conf_thresh > 0.0:
                        # Compute confidence of teacher predictions
                        conf_tea = prob_cons_tea_in_stu.max(dim=1)[0]
                        # Compute confidence mask
                        conf_mask = (conf_tea >= conf_thresh).float()[:, None, :, :]
                        # Record rate for reporting
                        conf_rate_acc += float(conf_mask.mean())
                        # Average confidence mask if requested
                        if not conf_per_pixel:
                            conf_mask = conf_mask.mean()
                        loss_mask = loss_mask * conf_mask
                    elif rampup > 0:
                        conf_rate_acc += ramp_val

                    # Compute per-pixel consistency loss.
                    # Note that the way we aggregate the loss across the
                    # class/channel dimension (1) depends on the loss
                    # function used. Generally, summing over the class
                    # dimension keeps the magnitude of the gradient of the
                    # loss w.r.t. the logits nearly constant w.r.t. the
                    # number of classes. When using logit-variance, dividing
                    # by `sqrt(num_classes)` helps.
                    if cons_loss_fn == 'var':
                        delta_prob = prob_cons_stu - prob_cons_tea_in_stu
                        consistency_loss = delta_prob * delta_prob
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'logits_var':
                        # Fixed: a stray `consistency_loss = delta_prob *
                        # delta_prob` here previously clobbered the squared
                        # logit difference with `delta_prob`, which is not
                        # defined in this branch (NameError at runtime).
                        delta_logits = logits_cons_stu - logits_cons_tea_in_stu
                        consistency_loss = delta_logits * delta_logits
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'logits_smoothl1':
                        consistency_loss = F.smooth_l1_loss(
                            logits_cons_stu, logits_cons_tea_in_stu,
                            reduce=False)
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'bce':
                        consistency_loss = network_architectures.robust_binary_crossentropy(
                            prob_cons_stu, prob_cons_tea_in_stu)
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'kld':
                        consistency_loss = F.kl_div(
                            F.log_softmax(logits_cons_stu, dim=1),
                            prob_cons_tea_in_stu, reduce=False)
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    else:
                        raise ValueError(
                            'Unknown consistency loss function {}'.format(
                                cons_loss_fn))

                    # Apply consistency loss mask and take the mean over
                    # pixels and images
                    consistency_loss = (consistency_loss * loss_mask).mean()

                    # Modulate with rampup if desired
                    if rampup > 0:
                        consistency_loss = consistency_loss * ramp_val

                    # Weight the consistency loss and back-prop
                    unsup_loss = consistency_loss * cons_weight
                    unsup_loss.backward()

                    consistency_loss_acc += float(consistency_loss.detach())

                    n_unsup_batches += 1

            student_optim.step()
            if teacher_optim is not None:
                teacher_optim.step()

            sup_loss_acc += float(sup_loss.detach())
            n_sup_batches += 1
            iter_i += 1

        sup_loss_acc /= n_sup_batches
        if n_unsup_batches > 0:
            consistency_loss_acc /= n_unsup_batches
            conf_rate_acc /= n_unsup_batches

        eval_net.eval()

        if src_val_iter is not None:
            src_iou_eval = evaluation.EvaluatorIoU(ds_src.num_classes,
                                                   bin_fill_holes)
            with torch.no_grad():
                for batch in src_val_iter:
                    batch_x = batch['image'].to(torch_device)
                    batch_y = batch['labels'].numpy()

                    logits = eval_net(batch_x)
                    pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                    for sample_i in range(len(batch_y)):
                        src_iou_eval.sample(batch_y[sample_i, 0],
                                            pred_y[sample_i],
                                            ignore_value=255)

            src_iou = src_iou_eval.score()
            src_miou = src_iou.mean()
        else:
            src_iou_eval = src_iou = src_miou = None

        tgt_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                               bin_fill_holes)
        with torch.no_grad():
            for batch in tgt_val_iter:
                batch_x = batch['image'].to(torch_device)
                batch_y = batch['labels'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i in range(len(batch_y)):
                    tgt_iou_eval.sample(batch_y[sample_i, 0],
                                        pred_y[sample_i],
                                        ignore_value=255)

        tgt_iou = tgt_iou_eval.score()
        tgt_miou = tgt_iou.mean()

        t2 = time.time()

        if ds_src is not ds_tgt:
            print('Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, '
                  'consistency loss={:.6f}, conf rate={:.3%}, '
                  'SRC VAL mIoU={:.3%}, TGT VAL mIoU={:.3%}'.format(
                      epoch_i + 1, t2 - t1, sup_loss_acc,
                      consistency_loss_acc, conf_rate_acc, src_miou,
                      tgt_miou))
            print('-- SRC {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in src_iou])))
            print('-- TGT {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))
        else:
            print('Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, '
                  'consistency loss={:.6f}, conf rate={:.3%}, VAL mIoU={:.3%}'
                  .format(epoch_i + 1, t2 - t1, sup_loss_acc,
                          consistency_loss_acc, conf_rate_acc, tgt_miou))
            print('-- {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))

    # Save network
    if save_model:
        model_path = os.path.join(submit_config.run_dir, "model.pth")
        torch.save(eval_net, model_path)

    if save_preds:
        out_dir = os.path.join(submit_config.run_dir, 'preds')
        os.makedirs(out_dir, exist_ok=True)
        with torch.no_grad():
            for batch in tgt_val_loader:
                batch_x = batch['image'].to(torch_device)
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    ds_tgt.save_prediction_by_index(
                        out_dir, pred_y[sample_i].astype(np.uint32),
                        sample_ndx)
    else:
        out_dir = None

    if test_loader is not None:
        test_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                                bin_fill_holes)
        with torch.no_grad():
            for batch in test_loader:
                batch_x = batch['image'].to(torch_device)
                # Fixed: previously scored against a stale `batch_y` left
                # over from the validation loop; read this batch's own
                # labels. NOTE(review): assumes the test pipeline provides
                # a 'labels' entry — confirm against eval_data_pipeline.
                batch_y = batch['labels'].numpy()
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    if save_preds:
                        ds_tgt.save_prediction_by_index(
                            out_dir, pred_y[sample_i].astype(np.uint32),
                            sample_ndx)
                    test_iou_eval.sample(batch_y[sample_i, 0],
                                         pred_y[sample_i],
                                         ignore_value=255)

        test_iou = test_iou_eval.score()
        test_miou = test_iou.mean()

        print('FINAL TEST: mIoU={:.3%}'.format(test_miou))
        print('-- TEST {}'.format(', '.join(
            ['{:.3%}'.format(x) for x in test_iou])))
def train_seg_semisup_vat_mt(
        submit_config: job_helper.SubmitConfig, dataset, model, arch, freeze_bn,
        opt_type, sgd_momentum, sgd_nesterov, sgd_weight_decay, learning_rate,
        lr_sched, lr_step_epochs, lr_step_gamma, lr_poly_power, teacher_alpha,
        bin_fill_holes, crop_size, aug_hflip, aug_vflip, aug_hvflip,
        aug_scale_hung, aug_max_scale, aug_scale_non_uniform, aug_rot_mag,
        vat_radius, adaptive_vat_radius, vat_dir_from_student, cons_loss_fn,
        cons_weight, conf_thresh, conf_per_pixel, rampup, unsup_batch_ratio,
        num_epochs, iters_per_epoch, batch_size, n_sup, n_unsup, n_val,
        split_seed, split_path, val_seed, save_preds, save_model, num_workers):
    """Train a semi-supervised semantic segmentation model using VAT
    (virtual adversarial training) consistency regularisation, optionally
    combined with a mean-teacher EMA model.

    A student network is trained with cross-entropy on the supervised batch
    and a consistency loss between teacher predictions on clean unsupervised
    samples and student predictions on adversarially perturbed ones.

    :param submit_config: job_helper run configuration; `run_dir` is used for
        saving the model and predictions
    :param dataset: dataset name passed to `datasets.load_dataset`
    :param model: 'mean_teacher' (EMA teacher) or 'pi' (teacher == student)
    :param arch: segmentation architecture name
        (key into `network_architectures.seg`)
    :param cons_loss_fn: consistency loss: 'var', 'bce', 'kld' or 'logits_var'
    :param cons_weight: weight applied to the consistency loss; <= 0 disables
        the unsupervised branch entirely
    :param conf_thresh: teacher confidence threshold for masking the
        consistency loss (0 disables)
    :param iters_per_epoch: iterations per epoch; -1 means one pass over the
        unsupervised set
    :return: None; progress and final scores are printed to stdout
    """
    # Capture all arguments for the settings report; drop the non-serialisable
    # submit_config.
    settings = locals().copy()
    del settings['submit_config']

    import os
    import time
    import itertools
    import math
    import numpy as np
    import torch, torch.nn as nn, torch.nn.functional as F
    from architectures import network_architectures
    import torch.utils.data
    from datapipe import datasets
    from datapipe import seg_data, seg_transforms, seg_transforms_cv
    import evaluation
    import optim_weight_ema
    import lr_schedules

    # crop_size arrives as a 'H,W' string; '' means no cropping.
    if crop_size == '':
        crop_size = None
    else:
        crop_size = [int(x.strip()) for x in crop_size.split(',')]

    torch_device = torch.device('cuda:0')

    #
    # Load data sets
    #
    ds_dict = datasets.load_dataset(dataset, n_val, val_seed, n_sup, n_unsup,
                                    split_seed, split_path)

    ds_src = ds_dict['ds_src']
    ds_tgt = ds_dict['ds_tgt']
    tgt_val_ndx = ds_dict['val_ndx_tgt']
    # A separate source validation set only exists when source and target
    # datasets differ (domain adaptation setting).
    src_val_ndx = ds_dict['val_ndx_src'] if ds_src is not ds_tgt else None
    test_ndx = ds_dict['test_ndx_tgt']
    sup_ndx = ds_dict['sup_ndx']
    unsup_ndx = ds_dict['unsup_ndx']

    n_classes = ds_src.num_classes
    # Used to scale the 'logits_var' consistency loss.
    root_n_classes = math.sqrt(n_classes)

    if bin_fill_holes and n_classes != 2:
        print(
            'Binary hole filling can only be used with binary (2-class) segmentation datasets'
        )
        return

    print('Loaded data')

    # Build network
    NetClass = network_architectures.seg.get(arch)

    student_net = NetClass(ds_src.num_classes).to(torch_device)

    # Pre-trained backbone parameters get a 10x smaller learning rate than
    # newly initialised layers.
    if opt_type == 'adam':
        student_optim = torch.optim.Adam([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ])
    elif opt_type == 'sgd':
        student_optim = torch.optim.SGD([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ],
                                        momentum=sgd_momentum,
                                        nesterov=sgd_nesterov,
                                        weight_decay=sgd_weight_decay)
    else:
        raise ValueError('Unknown opt_type {}'.format(opt_type))

    if model == 'mean_teacher':
        # Teacher weights are an EMA of the student's; they are never updated
        # by back-prop, so gradients are disabled.
        teacher_net = NetClass(ds_src.num_classes).to(torch_device)
        for p in teacher_net.parameters():
            p.requires_grad = False
        teacher_optim = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, teacher_alpha)
        eval_net = teacher_net
    elif model == 'pi':
        # Pi-model: the student acts as its own teacher.
        teacher_net = student_net
        teacher_optim = None
        eval_net = student_net
    else:
        print('Unknown model type {}'.format(model))
        return

    # Network used to compute the VAT perturbation direction.
    if vat_dir_from_student:
        vat_dir_net = student_net
    else:
        vat_dir_net = teacher_net

    BLOCK_SIZE = student_net.BLOCK_SIZE
    NET_MEAN, NET_STD = seg_transforms.get_mean_std(ds_tgt, student_net)

    if freeze_bn:
        if not hasattr(student_net, 'freeze_batchnorm'):
            raise ValueError(
                'Network {} does not support batchnorm freezing'.format(arch))

    # Label value 255 marks ignored pixels.
    clf_crossent_loss = nn.CrossEntropyLoss(ignore_index=255)

    print('Built network')

    if iters_per_epoch == -1:
        iters_per_epoch = len(unsup_ndx) // batch_size
    total_iters = iters_per_epoch * num_epochs

    lr_epoch_scheduler, lr_iter_scheduler = lr_schedules.make_lr_schedulers(
        optimizer=student_optim,
        total_iters=total_iters,
        schedule_type=lr_sched,
        step_epochs=lr_step_epochs,
        step_gamma=lr_step_gamma,
        poly_power=lr_poly_power)

    # Train data pipeline: transforms
    train_transforms = []

    if crop_size is not None:
        if aug_scale_hung:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropScaleHung(
                    crop_size, (0, 0),
                    uniform_scale=not aug_scale_non_uniform))
        elif aug_max_scale != 1.0 or aug_rot_mag != 0.0:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropRotateScale(
                    crop_size, (0, 0),
                    rot_mag=aug_rot_mag,
                    max_scale=aug_max_scale,
                    uniform_scale=not aug_scale_non_uniform,
                    constrain_rot_scale=True))
        else:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCrop(crop_size, (0, 0)))
    else:
        if aug_scale_hung:
            raise NotImplementedError('aug_scale_hung requires a crop_size')

    if aug_hflip or aug_vflip or aug_hvflip:
        train_transforms.append(
            seg_transforms_cv.SegCVTransformRandomFlip(aug_hflip, aug_vflip,
                                                       aug_hvflip))
    train_transforms.append(
        seg_transforms_cv.SegCVTransformNormalizeToTensor(NET_MEAN, NET_STD))

    # Train data pipeline: supervised and unsupervised data sets
    train_sup_ds = ds_src.dataset(
        labels=True,
        mask=False,
        xf=False,
        pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')
    train_unsup_ds = ds_src.dataset(
        labels=False,
        mask=True,
        xf=False,
        pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')

    collate_fn = seg_data.SegCollate(BLOCK_SIZE)

    # Train data pipeline: data loaders.  RepeatSampler lets us draw an
    # unbounded stream of batches so epochs are sliced out with islice below.
    sup_sampler = seg_data.RepeatSampler(
        torch.utils.data.SubsetRandomSampler(sup_ndx))
    train_sup_loader = torch.utils.data.DataLoader(train_sup_ds,
                                                   batch_size,
                                                   sampler=sup_sampler,
                                                   collate_fn=collate_fn,
                                                   num_workers=num_workers)
    if cons_weight > 0.0:
        unsup_sampler = seg_data.RepeatSampler(
            torch.utils.data.SubsetRandomSampler(unsup_ndx))
        train_unsup_loader = torch.utils.data.DataLoader(
            train_unsup_ds,
            batch_size,
            sampler=unsup_sampler,
            collate_fn=collate_fn,
            num_workers=num_workers)
    else:
        train_unsup_loader = None

    # Eval pipeline
    src_val_loader, tgt_val_loader, test_loader = datasets.eval_data_pipeline(
        ds_src, ds_tgt, src_val_ndx, tgt_val_ndx, test_ndx, batch_size,
        collate_fn, NET_MEAN, NET_STD, num_workers)

    # Report settings
    print('Settings:')
    print(', '.join([
        '{}={}'.format(key, settings[key])
        for key in sorted(list(settings.keys()))
    ]))

    # Report dataset size
    print('Dataset:')
    print('len(sup_ndx)={}'.format(len(sup_ndx)))
    print('len(unsup_ndx)={}'.format(len(unsup_ndx)))
    if ds_src is not ds_tgt:
        # Fixed: previously reported len(tgt_val_ndx) under the src label.
        print('len(src_val_ndx)={}'.format(len(src_val_ndx)))
        print('len(tgt_val_ndx)={}'.format(len(tgt_val_ndx)))
    else:
        print('len(val_ndx)={}'.format(len(tgt_val_ndx)))
    if test_ndx is not None:
        print('len(test_ndx)={}'.format(len(test_ndx)))
    if n_sup != -1:
        print('sup_ndx={}'.format(sup_ndx.tolist()))

    def normalize_eps(x):
        # Normalise each sample of a (N, C, H, W) batch to unit L2 norm;
        # the small epsilon guards against division by zero.
        x_flat = x.view(len(x), -1)
        mag = torch.sqrt((x_flat * x_flat).sum(dim=1))
        return x / (mag[:, None, None, None] + 1e-12)

    def normalized_noise(x, requires_grad=False, scale=1.0):
        # Random unit-norm noise with the same shape/device as x,
        # scaled by `scale`.
        eps = torch.randn(x.shape, dtype=torch.float, device=x.device)
        eps = normalize_eps(eps) * scale
        if requires_grad:
            eps = eps.clone().detach().requires_grad_(True)
        return eps

    def vat_direction(x):
        """
        Compute the VAT perturbation direction vector

        :param x: input image as a `(N, C, H, W)` tensor
        :return: `(vat_dir, y_pred_logits, y_pred_prob)` where vat_dir is a
            unit-norm `(N, C, H, W)` tensor
        """
        # Put the network used to get the VAT direction in eval mode and get
        # the predicted logits and probabilities for the batch of samples x
        vat_dir_net.eval()
        with torch.no_grad():
            y_pred_logits = vat_dir_net(x).detach()
        y_pred_prob = F.softmax(y_pred_logits, dim=1)

        # Initial noise offset vector with requires_grad=True
        noise_scale = 1.0e-6 * x.shape[2] * x.shape[3] / 1000
        eps = normalized_noise(x, requires_grad=True, scale=noise_scale)

        # Predict logits and probs for sample perturbed by eps
        eps_pred_logits = vat_dir_net(x.detach() + eps)
        eps_pred_prob = F.softmax(eps_pred_logits, dim=1)

        # Choose our loss function
        if cons_loss_fn == 'var':
            delta = (eps_pred_prob - y_pred_prob)
            loss = (delta * delta).sum()
        elif cons_loss_fn == 'bce':
            loss = network_architectures.robust_binary_crossentropy(
                eps_pred_prob, y_pred_prob).sum()
        elif cons_loss_fn == 'kld':
            # reduction='none' replaces the deprecated reduce=False.
            loss = F.kl_div(F.log_softmax(eps_pred_logits, dim=1),
                            y_pred_prob,
                            reduction='none').sum()
        elif cons_loss_fn == 'logits_var':
            delta = (eps_pred_logits - y_pred_logits)
            loss = (delta * delta).sum()
        else:
            raise ValueError(
                'Unknown consistency loss function {}'.format(cons_loss_fn))

        # Differentiate the loss w.r.t. the perturbation
        eps_adv = torch.autograd.grad(outputs=loss,
                                      inputs=eps,
                                      create_graph=True,
                                      retain_graph=True,
                                      only_inputs=True)[0]

        # Normalize the adversarial perturbation
        return normalize_eps(eps_adv), y_pred_logits, y_pred_prob

    def vat_perburbation(x, m):
        # Full VAT perturbation: direction from `vat_direction` scaled by a
        # (possibly content-adaptive) radius.  `m` (valid-pixel mask) is
        # accepted for interface compatibility but unused here.
        eps_adv_nrm, y_pred_logits, y_pred_prob = vat_direction(x)

        if adaptive_vat_radius:
            # We view semantic segmentation as predicting the class of a pixel
            # given a patch centred on that pixel.
            # The most similar patch in terms of pixel content to a patch P
            # is a patch Q whose central pixel is an immediate neighbour
            # of the central pixel P.
            # We therefore use the image Jacobian (gradient w.r.t. x and y) to
            # get a sense of the distance between neighbouring patches
            # so we can scale the VAT radius according to the image content.

            # Delta in vertical and horizontal directions
            delta_v = x[:, :, 2:, :] - x[:, :, :-2, :]
            delta_h = x[:, :, :, 2:] - x[:, :, :, :-2]
            # delta_h and delta_v are the difference between pixels where the
            # step size is 2, rather than 1, so divide by 2 to get the
            # magnitude of the Jacobian
            delta_v = delta_v.view(len(delta_v), -1)
            delta_h = delta_h.view(len(delta_h), -1)
            adv_radius = vat_radius * torch.sqrt(
                (delta_v**2).sum(dim=1) +
                (delta_h**2).sum(dim=1))[:, None, None, None] * 0.5
        else:
            # Fixed radius, scaled to the number of elements per sample.
            scale = math.sqrt(float(x.shape[1] * x.shape[2] * x.shape[3]))
            adv_radius = vat_radius * scale

        return (eps_adv_nrm * adv_radius).detach(), y_pred_logits, y_pred_prob

    # Track mIoU for early stopping
    # NOTE(review): these are initialised but never updated in this function;
    # early stopping appears to be unimplemented — confirm before relying on it.
    best_tgt_miou = None
    best_epoch = 0
    eval_net_state = {
        key: value.detach().cpu().numpy()
        for key, value in eval_net.state_dict().items()
    }

    # Create iterators; the RepeatSampler-backed loaders never exhaust, so
    # these iterators persist across epochs.
    train_sup_iter = iter(train_sup_loader)
    train_unsup_iter = iter(
        train_unsup_loader) if train_unsup_loader is not None else None

    iter_i = 0
    print('Training...')
    for epoch_i in range(num_epochs):
        if lr_epoch_scheduler is not None:
            lr_epoch_scheduler.step(epoch_i)

        t1 = time.time()

        # Sigmoid ramp-up factor for the consistency loss.
        if rampup > 0:
            ramp_val = network_architectures.sigmoid_rampup(epoch_i, rampup)
        else:
            ramp_val = 1.0

        student_net.train()
        if teacher_net is not student_net:
            teacher_net.train()

        if freeze_bn:
            student_net.freeze_batchnorm()
            if teacher_net is not student_net:
                teacher_net.freeze_batchnorm()

        sup_loss_acc = 0.0
        consistency_loss_acc = 0.0
        conf_rate_acc = 0.0
        n_sup_batches = 0
        n_unsup_batches = 0

        src_val_iter = iter(
            src_val_loader) if src_val_loader is not None else None
        tgt_val_iter = iter(
            tgt_val_loader) if tgt_val_loader is not None else None

        for sup_batch in itertools.islice(train_sup_iter, iters_per_epoch):
            if lr_iter_scheduler is not None:
                lr_iter_scheduler.step(iter_i)
            student_optim.zero_grad()

            #
            # Supervised branch
            #

            batch_x = sup_batch['image'].to(torch_device)
            batch_y = sup_batch['labels'].to(torch_device)

            logits_sup = student_net(batch_x)
            sup_loss = clf_crossent_loss(logits_sup, batch_y[:, 0, :, :])
            sup_loss.backward()

            if cons_weight > 0.0:
                for _ in range(unsup_batch_ratio):
                    #
                    # Unsupervised branch
                    #

                    unsup_batch = next(train_unsup_iter)

                    # Input images to torch tensor
                    batch_ux = unsup_batch['image'].to(torch_device)
                    batch_um = unsup_batch['mask'].to(torch_device)

                    # batch_um is a mask that is 1 for valid pixels, 0 for
                    # invalid pixels. It is used later on to scale the
                    # consistency loss, so that consistency loss is only
                    # computed for valid pixels.
                    # Explanation:
                    # When using geometric augmentations such as rotations,
                    # some pixels in the training crop may come from outside
                    # the bounds of the input image. These pixels will have a
                    # value of 0 in these masks. Similarly, when using scaled
                    # crops, the size of the crop from the input image that
                    # must be scaled to the size of the training crop may be
                    # larger than one/both of the input image dimensions.
                    # Pixels in the training crop that arise from outside the
                    # input image bounds will once again be given a value of 0
                    # in these masks.

                    # Compute VAT perburbation
                    x_perturb, logits_cons_tea, prob_cons_tea = vat_perburbation(
                        batch_ux, batch_um)

                    # Perturb image
                    batch_ux_adv = batch_ux + x_perturb

                    # Get teacher predictions for original image
                    with torch.no_grad():
                        logits_cons_tea = teacher_net(batch_ux).detach()
                    # Get student prediction for cut image
                    logits_cons_stu = student_net(batch_ux_adv)

                    # Logits -> probs
                    prob_cons_tea = F.softmax(logits_cons_tea, dim=1)
                    prob_cons_stu = F.softmax(logits_cons_stu, dim=1)

                    loss_mask = batch_um

                    # Confidence thresholding
                    if conf_thresh > 0.0:
                        # Compute confidence of teacher predictions
                        conf_tea = prob_cons_tea.max(dim=1)[0]
                        # Compute confidence mask
                        conf_mask = (conf_tea >=
                                     conf_thresh).float()[:, None, :, :]
                        # Record rate for reporting
                        conf_rate_acc += float(conf_mask.mean())
                        # Average confidence mask if requested
                        if not conf_per_pixel:
                            conf_mask = conf_mask.mean()

                        loss_mask = loss_mask * conf_mask
                    elif rampup > 0:
                        conf_rate_acc += ramp_val

                    # Compute per-pixel consistency loss
                    # Note that the way we aggregate the loss across the
                    # class/channel dimension (1) depends on the loss function
                    # used. Generally, summing over the class dimension keeps
                    # the magnitude of the gradient of the loss w.r.t. the
                    # logits nearly constant w.r.t. the number of classes.
                    # When using logit-variance, dividing by
                    # `sqrt(num_classes)` helps.
                    if cons_loss_fn == 'var':
                        delta_prob = prob_cons_stu - prob_cons_tea
                        consistency_loss = delta_prob * delta_prob
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'logits_var':
                        delta_logits = logits_cons_stu - logits_cons_tea
                        consistency_loss = delta_logits * delta_logits
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'bce':
                        consistency_loss = network_architectures.robust_binary_crossentropy(
                            prob_cons_stu, prob_cons_tea)
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'kld':
                        # reduction='none' replaces the deprecated
                        # reduce=False.
                        consistency_loss = F.kl_div(F.log_softmax(
                            logits_cons_stu, dim=1),
                                                    prob_cons_tea,
                                                    reduction='none')
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    else:
                        raise ValueError(
                            'Unknown consistency loss function {}'.format(
                                cons_loss_fn))

                    # Apply consistency loss mask and take the mean over
                    # pixels and images
                    consistency_loss = (consistency_loss * loss_mask).mean()

                    # Modulate with rampup if desired
                    if rampup > 0:
                        consistency_loss = consistency_loss * ramp_val

                    # Weight the consistency loss and back-prop
                    unsup_loss = consistency_loss * cons_weight
                    unsup_loss.backward()

                    consistency_loss_val = float(consistency_loss.detach())
                    consistency_loss_acc += consistency_loss_val

                    if np.isnan(consistency_loss_val):
                        print(
                            'NaN detected in consistency loss; bailing out...')
                        return

                    n_unsup_batches += 1

            student_optim.step()
            if teacher_optim is not None:
                teacher_optim.step()

            sup_loss_val = float(sup_loss.detach())
            sup_loss_acc += sup_loss_val
            if np.isnan(sup_loss_val):
                print('NaN detected in supervised loss; bailing out...')
                return
            n_sup_batches += 1
            iter_i += 1

        sup_loss_acc /= n_sup_batches
        if n_unsup_batches > 0:
            consistency_loss_acc /= n_unsup_batches
            conf_rate_acc /= n_unsup_batches

        eval_net.eval()

        # Source-domain validation (domain adaptation setting only).
        if src_val_iter is not None:
            src_iou_eval = evaluation.EvaluatorIoU(ds_src.num_classes,
                                                   bin_fill_holes)
            with torch.no_grad():
                for batch in src_val_iter:
                    batch_x = batch['image'].to(torch_device)
                    batch_y = batch['labels'].numpy()

                    logits = eval_net(batch_x)
                    pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                    for sample_i in range(len(batch_y)):
                        src_iou_eval.sample(batch_y[sample_i, 0],
                                            pred_y[sample_i],
                                            ignore_value=255)

            src_iou = src_iou_eval.score()
            src_miou = src_iou.mean()
        else:
            src_iou_eval = src_iou = src_miou = None

        # Target-domain validation.
        tgt_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                               bin_fill_holes)
        with torch.no_grad():
            for batch in tgt_val_iter:
                batch_x = batch['image'].to(torch_device)
                batch_y = batch['labels'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i in range(len(batch_y)):
                    tgt_iou_eval.sample(batch_y[sample_i, 0],
                                        pred_y[sample_i],
                                        ignore_value=255)

        tgt_iou = tgt_iou_eval.score()
        tgt_miou = tgt_iou.mean()

        t2 = time.time()

        if ds_src is not ds_tgt:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, '
                'SRC VAL mIoU={:.3%}, TGT VAL mIoU={:.3%}'.format(
                    epoch_i + 1, t2 - t1, sup_loss_acc, consistency_loss_acc,
                    conf_rate_acc, src_miou, tgt_miou))
            print('-- SRC {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in src_iou])))
            print('-- TGT {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))
        else:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, VAL mIoU={:.3%}'
                .format(epoch_i + 1, t2 - t1, sup_loss_acc,
                        consistency_loss_acc, conf_rate_acc, tgt_miou))
            print('-- {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))

    # Save the evaluation network (teacher for mean_teacher, student for pi).
    if save_model:
        model_path = os.path.join(submit_config.run_dir, "model.pth")
        torch.save(eval_net, model_path)

    # Save predictions for the target validation set.
    if save_preds:
        out_dir = os.path.join(submit_config.run_dir, 'preds')
        os.makedirs(out_dir, exist_ok=True)
        with torch.no_grad():
            for batch in tgt_val_loader:
                batch_x = batch['image'].to(torch_device)
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    ds_tgt.save_prediction_by_index(
                        out_dir, pred_y[sample_i].astype(np.uint32),
                        sample_ndx)
    else:
        out_dir = None

    # Final test-set pass: save predictions and score IoU.
    if test_loader is not None:
        test_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                                bin_fill_holes)
        with torch.no_grad():
            for batch in test_loader:
                batch_x = batch['image'].to(torch_device)
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    if save_preds:
                        ds_tgt.save_prediction_by_index(
                            out_dir, pred_y[sample_i].astype(np.uint32),
                            sample_ndx)
                    # NOTE(review): `batch_y` is not assigned in this loop —
                    # it still holds labels from the *last validation batch*
                    # above, so this IoU score looks wrong. If test batches
                    # carry a 'labels' entry it should be read here; confirm
                    # the test-loader schema before fixing.
                    test_iou_eval.sample(batch_y[sample_i, 0],
                                         pred_y[sample_i],
                                         ignore_value=255)

        test_iou = test_iou_eval.score()
        test_miou = test_iou.mean()

        print('FINAL TEST: mIoU={:.3%}'.format(test_miou))
        print('-- TEST {}'.format(', '.join(
            ['{:.3%}'.format(x) for x in test_iou])))
def train_toy2d(submit_config: job_helper.SubmitConfig, dataset,
                region_erode_radius, img_noise_std, n_sup, balance_classes,
                seed, sup_path, model, n_hidden, hidden_size, hidden_act,
                norm_layer, perturb_noise_std, dist_contour_range, conf_thresh,
                conf_avg, cons_weight, cons_loss_fn, cons_no_dropout,
                learning_rate, teacher_alpha, num_epochs, batch_size,
                render_cons_grad, render_pred, device, save_output):
    """Train a small MLP classifier on a 2D toy dataset with semi-supervised
    consistency regularisation, rendering a visualization of the decision
    function after every epoch (written to disk when `save_output`, otherwise
    shown in an OpenCV window).

    :param dataset: 'spiral' or 'img:<path>' (dataset generated from a
        black-and-white image)
    :param model: 'mean_teacher' (EMA teacher network), 'pi' or 'pi_onebatch'
    :param cons_loss_fn: consistency loss: 'bce', 'var' or 'logits_var'
    :param dist_contour_range: if > 0, only count consistency loss for
        perturbations that stay close to a contour of the distance map to the
        ground-truth class boundary (image datasets only)
    :param save_output: when True, write per-epoch plots into
        `submit_config.run_dir` instead of opening a GUI window
    :return: None; final error rate is printed to stdout
    """
    # Capture arguments for the settings report; submit_config is not
    # serialisable.
    settings = locals().copy()
    del settings['submit_config']

    import sys
    print('Command line:')
    print(' '.join(sys.argv))
    print('Settings:')
    print(', '.join(
        ['{}={}'.format(k, settings[k]) for k in sorted(settings.keys())]))

    import os
    import numpy as np
    import time
    import cv2
    from scipy.ndimage.morphology import distance_transform_edt
    import optim_weight_ema
    from toy2d import generate_data
    from datapipe.seg_data import RepeatSampler

    import torch, torch.nn as nn, torch.nn.functional as F
    import torch.utils.data

    rng = np.random.RandomState(seed)

    # Generate/load the dataset
    if dataset.startswith('img:'):
        # Generate a dataset from a black and white image
        image_path = dataset[4:]
        ds = generate_data.classification_dataset_from_image(
            image_path, region_erode_radius, img_noise_std, n_sup,
            balance_classes, rng)
        image = ds.image
    elif dataset == 'spiral':
        # Generate a spiral dataset
        ds = generate_data.spiral_classification_dataset(
            n_sup, balance_classes, rng)
        image = None
    else:
        print('Unknown dataset {}, should be spiral or img:<path>'.format(
            dataset))
        return

    # If a path to a supervised dataset has been provided, load it
    if sup_path is not None:
        ds.load_supervised(sup_path)

    # If we are constraining perturbations to lie along the contours of the
    # distance map to the ground truth class boundary
    if dist_contour_range > 0.0:
        if image is None:
            print(
                'Constraining perturbations to lying on distance map contours is only supported for \'image\' experiments'
            )
            return

        img_1 = image >= 0.5

        # Compute signed distance map to boundary: positive inside class 1,
        # negative inside class 0.
        dist_1 = distance_transform_edt(img_1)
        dist_0 = distance_transform_edt(~img_1)
        dist_map = dist_1 * img_1 + -dist_0 * (~img_1)
    else:
        dist_map = None

    # PyTorch device
    torch_device = torch.device(device)

    # Convert perturbation noise std-dev to [y,x]; fall back to a default of
    # 6 pixels per axis if the string cannot be parsed.
    try:
        perturb_noise_std = np.array(
            [float(x.strip()) for x in perturb_noise_std.split(',')])
    except ValueError:
        perturb_noise_std = np.array([6.0, 6.0])

    # Assume that perturbation noise std-dev is in pixel space (for image
    # experiments), so convert
    perturb_noise_std_real_scale = perturb_noise_std / ds.img_scale * 2.0
    perturb_noise_std_real_scale = torch.tensor(perturb_noise_std_real_scale,
                                                dtype=torch.float,
                                                device=torch_device)

    # Define the neural network model (an MLP)
    class Network(nn.Module):
        def __init__(self):
            super(Network, self).__init__()
            self.drop = nn.Dropout()
            hidden = []
            chn_in = 2  # 2D input points
            for i in range(n_hidden):
                # Optional weight re-parameterisation on each linear layer
                if norm_layer == 'spectral_norm':
                    hidden.append(
                        nn.utils.spectral_norm(nn.Linear(chn_in, hidden_size)))
                elif norm_layer == 'weight_norm':
                    hidden.append(
                        nn.utils.weight_norm(nn.Linear(chn_in, hidden_size)))
                else:
                    hidden.append(nn.Linear(chn_in, hidden_size))

                # Optional activation normalisation layer
                if norm_layer == 'batch_norm':
                    hidden.append(nn.BatchNorm1d(hidden_size))
                elif norm_layer == 'group_norm':
                    hidden.append(nn.GroupNorm(4, hidden_size))

                if hidden_act == 'relu':
                    hidden.append(nn.ReLU())
                elif hidden_act == 'lrelu':
                    hidden.append(nn.LeakyReLU(0.01))
                else:
                    raise ValueError
                chn_in = hidden_size
            self.hidden = nn.Sequential(*hidden)
            # Final layer; 2-class output
            self.l_final = nn.Linear(chn_in, 2)

        def forward(self, x, use_dropout=True):
            # `use_dropout` lets the consistency branch bypass dropout when
            # `cons_no_dropout` is set.
            x = self.hidden(x)
            if use_dropout:
                x = self.drop(x)
            x = self.l_final(x)
            return x

    # Build student network, optimizer and supervised loss criterion
    student_net = Network().to(torch_device)
    student_params = list(student_net.parameters())
    student_optimizer = torch.optim.Adam(student_params, lr=learning_rate)
    classification_criterion = nn.CrossEntropyLoss()

    # Build teacher network and optimizer
    if model == 'mean_teacher':
        # Teacher weights are an EMA of the student's; never trained directly.
        teacher_net = Network().to(torch_device)
        teacher_params = list(teacher_net.parameters())
        for param in teacher_params:
            param.requires_grad = False
        teacher_optimizer = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, ema_alpha=teacher_alpha)
        pred_net = teacher_net
    else:
        teacher_net = None
        teacher_optimizer = None
        pred_net = student_net

    # Robust BCE helper; epsilons guard log(0).
    def robust_binary_crossentropy(pred, tgt):
        inv_tgt = -tgt + 1.0
        inv_pred = -pred + 1.0 + 1e-6
        return -(tgt * torch.log(pred + 1.0e-6) +
                 inv_tgt * torch.log(inv_pred))

    # If we are constraining perturbations to lie on distance map contours,
    # load the distance map as a Torch tensor
    if dist_contour_range > 0.0:
        t_dist_map = torch.tensor(dist_map[None, None, ...],
                                  dtype=torch.float,
                                  device=torch_device)
    else:
        t_dist_map = None

    # Helper function to compute confidence thresholding factor
    def conf_factor(teacher_pred_prob):
        # Compute confidence
        conf_tea = torch.max(teacher_pred_prob, 1)[0]
        conf_tea = conf_tea.detach()
        # Compute factor based on threshold and `conf_avg` flag
        if conf_thresh > 0.0:
            conf_fac = (conf_tea >= conf_thresh).float()
        else:
            conf_fac = torch.ones(conf_tea.shape,
                                  dtype=torch.float,
                                  device=conf_tea.device)
        if conf_avg:
            # Replace the per-sample mask with its batch mean
            conf_fac = torch.ones_like(conf_fac) * conf_fac.mean()
        return conf_fac

    # Helper function that constrains consistency loss to operate only when
    # perturbations lie along distance map contours.
    # When this feature is enabled, it masks to zero the loss for any
    # unsupervised sample whose random perturbation deviates too far from the
    # distance map contour
    def dist_map_weighting(t_dist_map, batch_u_X, batch_u_X_1):
        if t_dist_map is not None and dist_contour_range > 0:
            # For each sample in `batch_u_X` and `batch_u_X_1`, both of which
            # are of shape `[n_points, [y,x]]` we want to get the value from
            # the distance map. For this we use
            # `torch.nn.functional.grid_sample`. This function expects grid
            # look-up co-ordinates to have the shape
            # `[batch, height, width, [x, y]]`.
            # We reshape `batch_u_X` and `batch_u_X_1` to
            # `[1, 1, n_points, [x,y]]` and stack along the height dimension,
            # making two rows to send to `grid_sample`.
            # The final shape will be `[1, 2, n_points, [x,y]]`:
            #   1 sample (1 image)
            #   2 rows; batch_u_X and batch_u_X_1
            #   n_points columns
            #   (x,y)

            # `[n_points, [y,x]]` -> `[1, 1, n_points, [x,y]]`
            sample_points_0 = torch.cat([
                batch_u_X[:, 1].view(1, 1, -1, 1), batch_u_X[:, 0].view(
                    1, 1, -1, 1)
            ],
                                        dim=3)
            # `[n_points, [y,x]]` -> `[1, 1, n_points, [x,y]]`
            sample_points_1 = torch.cat([
                batch_u_X_1[:, 1].view(1, 1, -1, 1), batch_u_X_1[:, 0].view(
                    1, 1, -1, 1)
            ],
                                        dim=3)
            # -> `[1, 2, n_points, [x,y]]`
            sample_points = torch.cat([sample_points_0, sample_points_1],
                                      dim=1)

            # Get distance to class boundary from distance map
            dist_from_boundary = F.grid_sample(t_dist_map, sample_points)

            # Get the squared difference between the distances from
            # `batch_u_X` to the boundary and the distances from
            # `batch_u_X_1` to the boundary.
            delta_dist_sqr = (dist_from_boundary[0, 0, 0, :] -
                              dist_from_boundary[0, 0, 1, :]).pow(2)

            # Per-sample loss mask based on difference between distances
            weight = (delta_dist_sqr <=
                      (dist_contour_range * dist_contour_range)).float()
            return weight
        else:
            # Feature disabled: weight every sample equally.
            return torch.ones(len(batch_u_X),
                              dtype=torch.float,
                              device=batch_u_X.device)

    # Supervised dataset, sampler and loader.  RepeatSampler makes the
    # supervised stream endless so zip() below is bounded by the
    # unsupervised loader.
    sup_dataset = torch.utils.data.TensorDataset(
        torch.tensor(ds.sup_X, dtype=torch.float),
        torch.tensor(ds.sup_y, dtype=torch.long))
    sup_sampler = RepeatSampler(torch.utils.data.RandomSampler(sup_dataset))
    sup_sep_loader = torch.utils.data.DataLoader(sup_dataset,
                                                 batch_size,
                                                 sampler=sup_sampler,
                                                 num_workers=1)

    # Unsupervised dataset, sampler and loader
    unsup_dataset = torch.utils.data.TensorDataset(
        torch.tensor(ds.unsup_X, dtype=torch.float))
    unsup_sampler = torch.utils.data.RandomSampler(unsup_dataset)
    unsup_loader = torch.utils.data.DataLoader(unsup_dataset,
                                               batch_size,
                                               sampler=unsup_sampler,
                                               num_workers=1)

    # Complete dataset and loader (used for the final error-rate report)
    all_dataset = torch.utils.data.TensorDataset(
        torch.tensor(ds.X, dtype=torch.float))
    all_loader = torch.utils.data.DataLoader(all_dataset,
                                             16384,
                                             shuffle=False,
                                             num_workers=1)

    # Grid points used to render visualizations
    vis_grid_dataset = torch.utils.data.TensorDataset(
        torch.tensor(ds.px_grid_vis, dtype=torch.float))
    vis_grid_loader = torch.utils.data.DataLoader(vis_grid_dataset,
                                                  16384,
                                                  shuffle=False,
                                                  num_workers=1)

    # Evaluation mode initially (for the first rendered frame)
    student_net.eval()
    if teacher_net is not None:
        teacher_net.eval()

    # Compute the magnitude of the gradient of the consistency loss at the
    # logits; used only for the optional gradient visualization
    # (`render_cons_grad`).
    def consistency_loss_logit_grad_mag(batch_u_X):
        u_shape = batch_u_X.shape
        batch_u_X_1 = batch_u_X + torch.randn(u_shape, dtype=torch.float, device=torch_device) * \
                                  perturb_noise_std_real_scale[None, :]

        student_optimizer.zero_grad()

        # `grads` captures the hook output (backward hooks cannot return
        # values to this scope directly).
        grads = [None]

        if teacher_net is not None:
            teacher_unsup_logits = teacher_net(batch_u_X).detach()
        else:
            teacher_unsup_logits = student_net(batch_u_X)
        teacher_unsup_prob = F.softmax(teacher_unsup_logits, dim=1)
        student_unsup_logits = student_net(batch_u_X_1)

        def grad_hook(grad):
            # Per-sample L2 norm of the gradient at the logits
            grads[0] = torch.sqrt((grad * grad).sum(dim=1))

        student_unsup_logits.register_hook(grad_hook)
        student_unsup_prob = F.softmax(student_unsup_logits, dim=1)

        weight = dist_map_weighting(t_dist_map, batch_u_X, batch_u_X_1)
        mod_fac = conf_factor(teacher_unsup_prob) * weight

        if cons_loss_fn == 'bce':
            aug_loss = robust_binary_crossentropy(student_unsup_prob,
                                                  teacher_unsup_prob)
            aug_loss = aug_loss.mean(dim=1) * mod_fac
            unsup_loss = aug_loss.mean()
        elif cons_loss_fn == 'var':
            d_aug_loss = student_unsup_prob - teacher_unsup_prob
            aug_loss = d_aug_loss * d_aug_loss
            aug_loss = aug_loss.mean(dim=1) * mod_fac
            unsup_loss = aug_loss.mean()
        elif cons_loss_fn == 'logits_var':
            d_aug_loss = student_unsup_logits - teacher_unsup_logits
            aug_loss = d_aug_loss * d_aug_loss
            aug_loss = aug_loss.mean(dim=1) * mod_fac
            unsup_loss = aug_loss.mean()
        else:
            raise ValueError

        # Backward pass triggers grad_hook, filling grads[0]
        unsup_loss.backward()

        return (grads[0].cpu().numpy(), )

    # Helper function for rendering an output image for visualization
    def render_output_image():
        # Generate output for plotting
        with torch.no_grad():
            vis_pred = []
            vis_grad = [] if render_cons_grad else None
            for (batch_X, ) in vis_grid_loader:
                batch_X = batch_X.to(torch_device)
                batch_pred_logits = pred_net(batch_X)

                if render_pred == 'prob':
                    # Probability of class 1
                    batch_vis = F.softmax(batch_pred_logits, dim=1)[:, 1]
                elif render_pred == 'class':
                    batch_vis = torch.argmax(batch_pred_logits, dim=1)
                else:
                    raise ValueError(
                        'Unknown prediction render {}'.format(render_pred))
                batch_vis = batch_vis.detach().cpu().numpy()
                vis_pred.append(batch_vis)

                if render_cons_grad:
                    batch_grad = consistency_loss_logit_grad_mag(batch_X)
                    vis_grad.append(batch_grad.detach().cpu().numpy())

            vis_pred = np.concatenate(vis_pred, axis=0)
            if render_cons_grad:
                vis_grad = np.concatenate(vis_grad, axis=0)

        out_image = ds.semisup_image_plot(vis_pred, vis_grad)
        return out_image

    # Output image for first frame (epoch 0, before any training)
    if save_output and submit_config.run_dir is not None:
        plot_path = os.path.join(submit_config.run_dir,
                                 'epoch_{:05d}.png'.format(0))
        cv2.imwrite(plot_path, render_output_image())
    else:
        cv2.imshow('Vis', render_output_image())
        k = cv2.waitKey(1)

    # Train
    print('|sup|={}'.format(len(ds.sup_X)))
    print('|unsup|={}'.format(len(ds.unsup_X)))
    print('|all|={}'.format(len(ds.X)))
    print('Training...')
    terminated = False
    for epoch in range(num_epochs):
        t1 = time.time()
        student_net.train()
        if teacher_net is not None:
            teacher_net.train()

        batch_sup_loss_accum = 0.0
        batch_conf_mask_sum_accum = 0.0
        batch_cons_loss_accum = 0.0
        batch_N_accum = 0.0
        # Supervised loader repeats endlessly, so the (finite) unsupervised
        # loader determines epoch length.
        for sup_batch, unsup_batch in zip(sup_sep_loader, unsup_loader):
            (batch_X, batch_y) = sup_batch
            (batch_u_X, ) = unsup_batch
            batch_X = batch_X.to(torch_device)
            batch_y = batch_y.to(torch_device)
            batch_u_X = batch_u_X.to(torch_device)

            # Apply perturbation to generate `batch_u_X_1`
            aug_perturbation = torch.randn(batch_u_X.shape,
                                           dtype=torch.float,
                                           device=torch_device)
            batch_u_X_1 = batch_u_X + aug_perturbation * perturb_noise_std_real_scale[
                None, :]

            # Supervised loss path
            student_optimizer.zero_grad()
            student_sup_logits = student_net(batch_X)
            sup_loss = classification_criterion(student_sup_logits, batch_y)

            if cons_weight > 0.0:
                # Consistency loss path

                # Logits are computed differently depending on model
                if model == 'mean_teacher':
                    teacher_unsup_logits = teacher_net(
                        batch_u_X, use_dropout=not cons_no_dropout).detach()
                    student_unsup_logits = student_net(
                        batch_u_X_1, use_dropout=not cons_no_dropout)
                elif model == 'pi':
                    teacher_unsup_logits = student_net(
                        batch_u_X, use_dropout=not cons_no_dropout)
                    student_unsup_logits = student_net(
                        batch_u_X_1, use_dropout=not cons_no_dropout)
                elif model == 'pi_onebatch':
                    # Run clean and perturbed samples through the network as
                    # a single concatenated batch, then split.
                    batch_both = torch.cat([batch_u_X, batch_u_X_1], dim=0)
                    both_unsup_logits = student_net(
                        batch_both, use_dropout=not cons_no_dropout)
                    teacher_unsup_logits = both_unsup_logits[:len(batch_u_X)]
                    student_unsup_logits = both_unsup_logits[len(batch_u_X):]
                else:
                    raise RuntimeError

                # Compute predicted probabilities
                teacher_unsup_prob = F.softmax(teacher_unsup_logits, dim=1)
                student_unsup_prob = F.softmax(student_unsup_logits, dim=1)

                # Distance map weighting
                # (if dist_contour_range is 0 then weight will just be 1)
                weight = dist_map_weighting(t_dist_map, batch_u_X,
                                            batch_u_X_1)

                # Confidence thresholding
                conf_fac = conf_factor(teacher_unsup_prob)
                mod_fac = conf_fac * weight

                # Compute consistency loss
                if cons_loss_fn == 'bce':
                    aug_loss = robust_binary_crossentropy(
                        student_unsup_prob, teacher_unsup_prob)
                    aug_loss = aug_loss.mean(dim=1) * mod_fac
                    cons_loss = aug_loss.sum() / weight.sum()
                elif cons_loss_fn == 'var':
                    d_aug_loss = student_unsup_prob - teacher_unsup_prob
                    aug_loss = d_aug_loss * d_aug_loss
                    aug_loss = aug_loss.mean(dim=1) * mod_fac
                    cons_loss = aug_loss.sum() / weight.sum()
                elif cons_loss_fn == 'logits_var':
                    d_aug_loss = student_unsup_logits - teacher_unsup_logits
                    aug_loss = d_aug_loss * d_aug_loss
                    aug_loss = aug_loss.mean(dim=1) * mod_fac
                    cons_loss = aug_loss.sum() / weight.sum()
                else:
                    raise ValueError

                # Combine supervised and consistency loss
                loss = sup_loss + cons_loss * cons_weight

                conf_rate = float(conf_fac.sum())
            else:
                loss = sup_loss
                conf_rate = 0.0
                cons_loss = 0.0

            loss.backward()
            student_optimizer.step()
            if teacher_optimizer is not None:
                # EMA update of teacher weights
                teacher_optimizer.step()

            batch_sup_loss_accum += float(sup_loss)
            batch_conf_mask_sum_accum += conf_rate
            batch_cons_loss_accum += float(cons_loss)
            batch_N_accum += len(batch_X)

        if batch_N_accum > 0:
            batch_sup_loss_accum /= batch_N_accum
            batch_conf_mask_sum_accum /= batch_N_accum
            batch_cons_loss_accum /= batch_N_accum

        student_net.eval()
        if teacher_net is not None:
            teacher_net.eval()

        # Generate output for plotting
        if save_output and submit_config.run_dir is not None:
            plot_path = os.path.join(submit_config.run_dir,
                                     'epoch_{:05d}.png'.format(epoch + 1))
            cv2.imwrite(plot_path, render_output_image())
        else:
            cv2.imshow('Vis', render_output_image())
            k = cv2.waitKey(1)
            # ESC key aborts training
            if (k & 255) == 27:
                terminated = True
                break

        t2 = time.time()

        print(
            'Epoch {}: took {:.3f}s: clf loss={:.6f}, conf rate={:.3%}, cons loss={:.6f}'
            .format(epoch + 1, t2 - t1, batch_sup_loss_accum,
                    batch_conf_mask_sum_accum, batch_cons_loss_accum))

    # Get final score based on all samples
    all_pred_y = []
    with torch.no_grad():
        for (batch_X, ) in all_loader:
            batch_X = batch_X.to(torch_device)
            batch_pred_logits = pred_net(batch_X)
            batch_pred_cls = torch.argmax(batch_pred_logits, dim=1)
            all_pred_y.append(batch_pred_cls.detach().cpu().numpy())
    all_pred_y = np.concatenate(all_pred_y, axis=0)

    err_rate = (all_pred_y != ds.y).mean()
    print(
        'FINAL RESULT: Error rate={:.6%} (supervised and unsupervised samples)'
        .format(err_rate))

    if not save_output:
        # Close output window
        if not terminated:
            cv2.waitKey()
        cv2.destroyAllWindows()
def train_seg_semisup_mask_mt(
        submit_config: job_helper.SubmitConfig, dataset, model, arch, freeze_bn,
        opt_type, sgd_momentum, sgd_nesterov, sgd_weight_decay, learning_rate,
        lr_sched, lr_step_epochs, lr_step_gamma, lr_poly_power, teacher_alpha,
        bin_fill_holes, crop_size, aug_hflip, aug_vflip, aug_hvflip,
        aug_scale_hung, aug_max_scale, aug_scale_non_uniform, aug_rot_mag,
        mask_mode, mask_prop_range, boxmask_n_boxes,
        boxmask_fixed_aspect_ratio, boxmask_by_size, boxmask_outside_bounds,
        boxmask_no_invert, cons_loss_fn, cons_weight, conf_thresh,
        conf_per_pixel, rampup, unsup_batch_ratio, num_epochs, iters_per_epoch,
        batch_size, n_sup, n_unsup, n_val, split_seed, split_path, val_seed,
        save_preds, save_model, num_workers):
    """Train a semi-supervised semantic segmentation model using mask-driven
    consistency regularisation (CutMix / Cutout style) with an optional
    mean-teacher.

    A student network is trained with a supervised cross-entropy loss on
    labelled samples, plus a consistency loss on unlabelled samples: the
    teacher's predictions on un-perturbed images must match the student's
    predictions on masked/mixed images.  With ``model='mean_teacher'`` the
    teacher is an EMA copy of the student; with ``model='pi'`` the student is
    its own teacher.

    Key parameters (others are passed straight to the helpers they name):
        submit_config: job context providing ``run_dir`` for outputs.
        mask_mode: 'zero' cuts masked regions to zero; 'mix' mixes two
            unsupervised samples under the mask (CutMix).
        mask_prop_range: float or 'lo:hi' string giving mask area proportion.
        cons_loss_fn: one of 'var', 'logits_var', 'logits_smoothl1', 'bce',
            'kld' — selects the consistency distance measure.
        conf_thresh / conf_per_pixel: teacher-confidence thresholding of the
            consistency loss, per-pixel or averaged per batch.
        rampup: epochs of sigmoid ramp-up applied to the consistency loss.
        unsup_batch_ratio: unsupervised batches processed per supervised batch.
        iters_per_epoch: -1 derives it from the unsupervised set size.

    Returns:
        None.  Progress and final results are printed; predictions and the
        model are optionally written under ``submit_config.run_dir``.

    Raises:
        ValueError: unknown ``mask_mode``, ``opt_type`` or ``cons_loss_fn``,
            or ``freeze_bn`` requested on a network without support for it.
        NotImplementedError: ``aug_scale_hung`` without a ``crop_size``.
    """
    settings = locals().copy()
    del settings['submit_config']

    # 'lo:hi' selects a random mask-proportion range; a single float is fixed.
    if ':' in mask_prop_range:
        a, b = mask_prop_range.split(':')
        mask_prop_range = (float(a.strip()), float(b.strip()))
        del a, b
    else:
        mask_prop_range = float(mask_prop_range)

    if mask_mode == 'zero':
        mask_mix = False
    elif mask_mode == 'mix':
        mask_mix = True
    else:
        raise ValueError('Unknown mask_mode {}'.format(mask_mode))
    del mask_mode

    import os
    import math
    import time
    import itertools
    import numpy as np
    import torch.nn as nn, torch.nn.functional as F
    from architectures import network_architectures
    import torch.utils.data
    from datapipe import datasets
    from datapipe import seg_data, seg_transforms, seg_transforms_cv
    import evaluation
    import optim_weight_ema
    import mask_gen
    import lr_schedules

    if crop_size == '':
        crop_size = None
    else:
        crop_size = [int(x.strip()) for x in crop_size.split(',')]

    torch_device = torch.device('cuda:0')

    #
    # Load data sets
    #
    ds_dict = datasets.load_dataset(dataset, n_val, val_seed, n_sup, n_unsup,
                                    split_seed, split_path)

    ds_src = ds_dict['ds_src']
    ds_tgt = ds_dict['ds_tgt']
    tgt_val_ndx = ds_dict['val_ndx_tgt']
    # A separate source validation set only exists in domain-transfer setups.
    src_val_ndx = ds_dict['val_ndx_src'] if ds_src is not ds_tgt else None
    test_ndx = ds_dict['test_ndx_tgt']
    sup_ndx = ds_dict['sup_ndx']
    unsup_ndx = ds_dict['unsup_ndx']

    n_classes = ds_src.num_classes
    # Used to normalise logit-space losses w.r.t. the number of classes.
    root_n_classes = math.sqrt(n_classes)

    if bin_fill_holes and n_classes != 2:
        print(
            'Binary hole filling can only be used with binary (2-class) segmentation datasets'
        )
        return

    print('Loaded data')

    # Build network
    NetClass = network_architectures.seg.get(arch)

    student_net = NetClass(ds_src.num_classes).to(torch_device)

    # Pre-trained backbone parameters get a 10x smaller learning rate than
    # freshly initialised heads.
    if opt_type == 'adam':
        student_optim = torch.optim.Adam([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ])
    elif opt_type == 'sgd':
        student_optim = torch.optim.SGD([
            dict(params=student_net.pretrained_parameters(),
                 lr=learning_rate * 0.1),
            dict(params=student_net.new_parameters(), lr=learning_rate)
        ],
                                        momentum=sgd_momentum,
                                        nesterov=sgd_nesterov,
                                        weight_decay=sgd_weight_decay)
    else:
        raise ValueError('Unknown opt_type {}'.format(opt_type))

    if model == 'mean_teacher':
        # Teacher is an EMA of the student; it is never trained by gradient.
        teacher_net = NetClass(ds_src.num_classes).to(torch_device)
        for p in teacher_net.parameters():
            p.requires_grad = False
        teacher_optim = optim_weight_ema.EMAWeightOptimizer(
            teacher_net, student_net, teacher_alpha)
        eval_net = teacher_net
    elif model == 'pi':
        # Pi model: the student is its own teacher.
        teacher_net = student_net
        teacher_optim = None
        eval_net = student_net
    else:
        print('Unknown model type {}'.format(model))
        return

    BLOCK_SIZE = student_net.BLOCK_SIZE
    NET_MEAN, NET_STD = seg_transforms.get_mean_std(ds_tgt, student_net)

    if freeze_bn:
        if not hasattr(student_net, 'freeze_batchnorm'):
            raise ValueError(
                'Network {} does not support batchnorm freezing'.format(arch))

    clf_crossent_loss = nn.CrossEntropyLoss(ignore_index=255)

    print('Built network')

    mask_generator = mask_gen.BoxMaskGenerator(
        prop_range=mask_prop_range,
        n_boxes=boxmask_n_boxes,
        random_aspect_ratio=not boxmask_fixed_aspect_ratio,
        prop_by_area=not boxmask_by_size,
        within_bounds=not boxmask_outside_bounds,
        invert=not boxmask_no_invert)

    if iters_per_epoch == -1:
        iters_per_epoch = len(unsup_ndx) // batch_size
    total_iters = iters_per_epoch * num_epochs

    lr_epoch_scheduler, lr_iter_scheduler = lr_schedules.make_lr_schedulers(
        optimizer=student_optim,
        total_iters=total_iters,
        schedule_type=lr_sched,
        step_epochs=lr_step_epochs,
        step_gamma=lr_step_gamma,
        poly_power=lr_poly_power)

    train_transforms = []
    eval_transforms = []

    if crop_size is not None:
        if aug_scale_hung:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropScaleHung(
                    crop_size, (0, 0), uniform_scale=not aug_scale_non_uniform))
        elif aug_max_scale != 1.0 or aug_rot_mag != 0.0:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCropRotateScale(
                    crop_size, (0, 0),
                    rot_mag=aug_rot_mag,
                    max_scale=aug_max_scale,
                    uniform_scale=not aug_scale_non_uniform,
                    constrain_rot_scale=True))
        else:
            train_transforms.append(
                seg_transforms_cv.SegCVTransformRandomCrop(crop_size, (0, 0)))
    else:
        if aug_scale_hung:
            raise NotImplementedError('aug_scale_hung requires a crop_size')

    if aug_hflip or aug_vflip or aug_hvflip:
        train_transforms.append(
            seg_transforms_cv.SegCVTransformRandomFlip(aug_hflip, aug_vflip,
                                                       aug_hvflip))
    train_transforms.append(
        seg_transforms_cv.SegCVTransformNormalizeToTensor(NET_MEAN, NET_STD))
    eval_transforms.append(
        seg_transforms_cv.SegCVTransformNormalizeToTensor(NET_MEAN, NET_STD))

    train_sup_ds = ds_src.dataset(
        labels=True,
        mask=False,
        xf=False,
        pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')
    train_unsup_ds = ds_src.dataset(
        labels=False,
        mask=True,
        xf=False,
        pair=False,
        transforms=seg_transforms.SegTransformCompose(train_transforms),
        pipeline_type='cv')
    eval_ds = ds_src.dataset(
        labels=True,
        mask=False,
        xf=False,
        pair=False,
        transforms=seg_transforms.SegTransformCompose(eval_transforms),
        pipeline_type='cv')

    # The unsupervised collate function additionally attaches box-mask
    # parameters to each batch; the masks themselves are rendered on-device.
    add_mask_params_to_batch = mask_gen.AddMaskParamsToBatch(mask_generator)
    collate_fn = seg_data.SegCollate(BLOCK_SIZE)
    mask_collate_fn = seg_data.SegCollate(
        BLOCK_SIZE, batch_aug_fn=add_mask_params_to_batch)

    # Train data pipeline: data loaders.  RepeatSampler lets a single iterator
    # serve all epochs (see `train_sup_iter` below).
    sup_sampler = seg_data.RepeatSampler(
        torch.utils.data.SubsetRandomSampler(sup_ndx))
    train_sup_loader = torch.utils.data.DataLoader(train_sup_ds,
                                                   batch_size,
                                                   sampler=sup_sampler,
                                                   collate_fn=collate_fn,
                                                   num_workers=num_workers)
    if cons_weight > 0.0:
        unsup_sampler = seg_data.RepeatSampler(
            torch.utils.data.SubsetRandomSampler(unsup_ndx))
        train_unsup_loader_0 = torch.utils.data.DataLoader(
            train_unsup_ds,
            batch_size,
            sampler=unsup_sampler,
            collate_fn=mask_collate_fn,
            num_workers=num_workers)
        if mask_mix:
            # Mix mode pairs each masked batch with a second unsupervised batch.
            train_unsup_loader_1 = torch.utils.data.DataLoader(
                train_unsup_ds,
                batch_size,
                sampler=unsup_sampler,
                collate_fn=collate_fn,
                num_workers=num_workers)
        else:
            train_unsup_loader_1 = None
    else:
        train_unsup_loader_0 = None
        train_unsup_loader_1 = None

    # Eval pipeline
    src_val_loader, tgt_val_loader, test_loader = datasets.eval_data_pipeline(
        ds_src, ds_tgt, src_val_ndx, tgt_val_ndx, test_ndx, batch_size,
        collate_fn, NET_MEAN, NET_STD, num_workers)

    # Report settings
    print('Settings:')
    print(', '.join([
        '{}={}'.format(key, settings[key])
        for key in sorted(list(settings.keys()))
    ]))

    # Report dataset size
    print('Dataset:')
    print('len(sup_ndx)={}'.format(len(sup_ndx)))
    print('len(unsup_ndx)={}'.format(len(unsup_ndx)))
    if ds_src is not ds_tgt:
        # BUGFIX: previously reported len(tgt_val_ndx) under the src label.
        print('len(src_val_ndx)={}'.format(len(src_val_ndx)))
        print('len(tgt_val_ndx)={}'.format(len(tgt_val_ndx)))
    else:
        print('len(val_ndx)={}'.format(len(tgt_val_ndx)))
    if test_ndx is not None:
        print('len(test_ndx)={}'.format(len(test_ndx)))
    if n_sup != -1:
        print('sup_ndx={}'.format(sup_ndx.tolist()))

    # Track mIoU for early stopping
    # NOTE(review): best_tgt_miou / best_epoch / eval_net_state are initialised
    # here but never updated in this function — early stopping appears to be
    # unimplemented; kept for compatibility.
    best_tgt_miou = None
    best_epoch = 0

    eval_net_state = {
        key: value.detach().cpu().numpy()
        for key, value in eval_net.state_dict().items()
    }

    # Create iterators.  The RepeatSampler-backed loaders never exhaust, so a
    # single iterator is sliced per epoch with itertools.islice.
    train_sup_iter = iter(train_sup_loader)
    train_unsup_iter_0 = iter(
        train_unsup_loader_0) if train_unsup_loader_0 is not None else None
    train_unsup_iter_1 = iter(
        train_unsup_loader_1) if train_unsup_loader_1 is not None else None

    iter_i = 0
    print('Training...')
    for epoch_i in range(num_epochs):
        if lr_epoch_scheduler is not None:
            lr_epoch_scheduler.step(epoch_i)

        t1 = time.time()

        if rampup > 0:
            ramp_val = network_architectures.sigmoid_rampup(epoch_i, rampup)
        else:
            ramp_val = 1.0

        student_net.train()
        if teacher_net is not student_net:
            teacher_net.train()

        if freeze_bn:
            student_net.freeze_batchnorm()
            if teacher_net is not student_net:
                teacher_net.freeze_batchnorm()

        sup_loss_acc = 0.0
        consistency_loss_acc = 0.0
        conf_rate_acc = 0.0
        n_sup_batches = 0
        n_unsup_batches = 0

        # Fresh validation iterators each epoch; consumed in the eval section.
        src_val_iter = iter(
            src_val_loader) if src_val_loader is not None else None
        tgt_val_iter = iter(
            tgt_val_loader) if tgt_val_loader is not None else None

        for sup_batch in itertools.islice(train_sup_iter, iters_per_epoch):
            if lr_iter_scheduler is not None:
                lr_iter_scheduler.step(iter_i)
            student_optim.zero_grad()

            #
            # Supervised branch
            #
            batch_x = sup_batch['image'].to(torch_device)
            batch_y = sup_batch['labels'].to(torch_device)

            logits_sup = student_net(batch_x)
            # Labels arrive as (N,1,H,W); CrossEntropyLoss wants (N,H,W).
            sup_loss = clf_crossent_loss(logits_sup, batch_y[:, 0, :, :])
            sup_loss.backward()

            if cons_weight > 0.0:
                for _ in range(unsup_batch_ratio):
                    #
                    # Unsupervised branch
                    #
                    if mask_mix:
                        # Mix mode: batch consists of paired unsupervised
                        # samples and mask parameters
                        unsup_batch0 = next(train_unsup_iter_0)
                        unsup_batch1 = next(train_unsup_iter_1)
                        batch_ux0 = unsup_batch0['image'].to(torch_device)
                        batch_um0 = unsup_batch0['mask'].to(torch_device)
                        batch_ux1 = unsup_batch1['image'].to(torch_device)
                        batch_um1 = unsup_batch1['mask'].to(torch_device)
                        batch_mask_params = unsup_batch0['mask_params'].to(
                            torch_device)

                        # batch_um0 and batch_um1 are masks that are 1 for
                        # valid pixels, 0 for invalid pixels.  They scale the
                        # consistency loss so that it is only computed for
                        # valid pixels.  Pixels of a training crop that come
                        # from outside the bounds of the input image (due to
                        # rotation or scaled crops) have the value 0.

                        # Convert mask parameters to masks of shape (N,1,H,W)
                        batch_mix_masks = mask_generator.torch_masks_from_params(
                            batch_mask_params, batch_ux0.shape[2:4],
                            torch_device)

                        # Mix images with masks
                        batch_ux_mixed = batch_ux0 * (
                            1 - batch_mix_masks) + batch_ux1 * batch_mix_masks
                        batch_um_mixed = batch_um0 * (
                            1 - batch_mix_masks) + batch_um1 * batch_mix_masks

                        # Get teacher predictions for original images
                        with torch.no_grad():
                            logits_u0_tea = teacher_net(batch_ux0).detach()
                            logits_u1_tea = teacher_net(batch_ux1).detach()
                        # Get student prediction for mixed image
                        logits_cons_stu = student_net(batch_ux_mixed)

                        # Mix teacher predictions using same mask.  It makes no
                        # difference whether we do this with logits or
                        # probabilities as the mask pixels are either 1 or 0.
                        logits_cons_tea = logits_u0_tea * (
                            1 - batch_mix_masks) + logits_u1_tea * batch_mix_masks

                        # Logits -> probs
                        prob_cons_tea = F.softmax(logits_cons_tea, dim=1)
                        prob_cons_stu = F.softmax(logits_cons_stu, dim=1)

                        loss_mask = batch_um_mixed
                    else:
                        # Cut mode: batch consists of unsupervised samples and
                        # mask params
                        unsup_batch = next(train_unsup_iter_0)
                        batch_ux = unsup_batch['image'].to(torch_device)
                        batch_um = unsup_batch['mask'].to(torch_device)
                        batch_mask_params = unsup_batch['mask_params'].to(
                            torch_device)

                        # Convert mask parameters to masks of shape (N,1,H,W)
                        batch_cut_masks = mask_generator.torch_masks_from_params(
                            batch_mask_params, batch_ux.shape[2:4],
                            torch_device)

                        # Cut image with mask (mask regions to zero)
                        batch_ux_cut = batch_ux * batch_cut_masks

                        # Get teacher predictions for original image
                        with torch.no_grad():
                            logits_cons_tea = teacher_net(batch_ux).detach()
                        # Get student prediction for cut image
                        logits_cons_stu = student_net(batch_ux_cut)

                        # Logits -> probs
                        prob_cons_tea = F.softmax(logits_cons_tea, dim=1)
                        prob_cons_stu = F.softmax(logits_cons_stu, dim=1)

                        loss_mask = batch_cut_masks * batch_um

                    # -- shared by mix and cut --

                    # Confidence thresholding
                    if conf_thresh > 0.0:
                        # Compute confidence of teacher predictions
                        conf_tea = prob_cons_tea.max(dim=1)[0]
                        # Compute confidence mask
                        conf_mask = (conf_tea >=
                                     conf_thresh).float()[:, None, :, :]
                        # Record rate for reporting
                        conf_rate_acc += float(conf_mask.mean())
                        # Average confidence mask if requested
                        if not conf_per_pixel:
                            conf_mask = conf_mask.mean()

                        loss_mask = loss_mask * conf_mask
                    elif rampup > 0:
                        conf_rate_acc += ramp_val

                    # Compute per-pixel consistency loss.  Summing over the
                    # class dimension keeps the magnitude of the gradient of
                    # the loss w.r.t. the logits nearly constant w.r.t. the
                    # number of classes; for logit-space losses dividing by
                    # sqrt(num_classes) helps.
                    if cons_loss_fn == 'var':
                        delta_prob = prob_cons_stu - prob_cons_tea
                        consistency_loss = delta_prob * delta_prob
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'logits_var':
                        delta_logits = logits_cons_stu - logits_cons_tea
                        consistency_loss = delta_logits * delta_logits
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'logits_smoothl1':
                        # reduction='none' replaces the deprecated reduce=False
                        consistency_loss = F.smooth_l1_loss(logits_cons_stu,
                                                            logits_cons_tea,
                                                            reduction='none')
                        consistency_loss = consistency_loss.sum(
                            dim=1, keepdim=True) / root_n_classes
                    elif cons_loss_fn == 'bce':
                        consistency_loss = network_architectures.robust_binary_crossentropy(
                            prob_cons_stu, prob_cons_tea)
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    elif cons_loss_fn == 'kld':
                        consistency_loss = F.kl_div(F.log_softmax(
                            logits_cons_stu, dim=1),
                                                    prob_cons_tea,
                                                    reduction='none')
                        consistency_loss = consistency_loss.sum(dim=1,
                                                                keepdim=True)
                    else:
                        raise ValueError(
                            'Unknown consistency loss function {}'.format(
                                cons_loss_fn))

                    # Apply consistency loss mask and take the mean over
                    # pixels and images
                    consistency_loss = (consistency_loss * loss_mask).mean()

                    # Modulate with rampup if desired
                    if rampup > 0:
                        consistency_loss = consistency_loss * ramp_val

                    # Weight the consistency loss and back-prop
                    unsup_loss = consistency_loss * cons_weight
                    unsup_loss.backward()

                    consistency_loss_acc += float(consistency_loss.detach())

                    n_unsup_batches += 1

            student_optim.step()
            if teacher_optim is not None:
                teacher_optim.step()

            sup_loss_val = float(sup_loss.detach())
            if np.isnan(sup_loss_val):
                print('NaN detected; network dead, bailing.')
                return

            sup_loss_acc += sup_loss_val
            n_sup_batches += 1
            iter_i += 1

        sup_loss_acc /= n_sup_batches
        if n_unsup_batches > 0:
            consistency_loss_acc /= n_unsup_batches
            conf_rate_acc /= n_unsup_batches

        eval_net.eval()

        if ds_src is not ds_tgt:
            src_iou_eval = evaluation.EvaluatorIoU(ds_src.num_classes,
                                                   bin_fill_holes)
            with torch.no_grad():
                for batch in src_val_iter:
                    batch_x = batch['image'].to(torch_device)
                    batch_y = batch['labels'].numpy()

                    logits = eval_net(batch_x)
                    pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                    for sample_i in range(len(batch_y)):
                        src_iou_eval.sample(batch_y[sample_i, 0],
                                            pred_y[sample_i],
                                            ignore_value=255)

            src_iou = src_iou_eval.score()
            src_miou = src_iou.mean()
        else:
            src_iou_eval = src_iou = src_miou = None

        tgt_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                               bin_fill_holes)
        with torch.no_grad():
            for batch in tgt_val_iter:
                batch_x = batch['image'].to(torch_device)
                batch_y = batch['labels'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i in range(len(batch_y)):
                    tgt_iou_eval.sample(batch_y[sample_i, 0],
                                        pred_y[sample_i],
                                        ignore_value=255)

        tgt_iou = tgt_iou_eval.score()
        tgt_miou = tgt_iou.mean()

        t2 = time.time()

        if ds_src is not ds_tgt:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, '
                'SRC VAL mIoU={:.3%}, TGT VAL mIoU={:.3%}'.format(
                    epoch_i + 1, t2 - t1, sup_loss_acc, consistency_loss_acc,
                    conf_rate_acc, src_miou, tgt_miou))
            print('-- SRC {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in src_iou])))
            print('-- TGT {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))
        else:
            print(
                'Epoch {}: took {:.3f}s, TRAIN clf loss={:.6f}, consistency loss={:.6f}, conf rate={:.3%}, VAL mIoU={:.3%}'
                .format(epoch_i + 1, t2 - t1, sup_loss_acc,
                        consistency_loss_acc, conf_rate_acc, tgt_miou))
            print('-- {}'.format(', '.join(
                ['{:.3%}'.format(x) for x in tgt_iou])))

    if save_model:
        model_path = os.path.join(submit_config.run_dir, "model.pth")
        torch.save(eval_net, model_path)

    if save_preds:
        out_dir = os.path.join(submit_config.run_dir, 'preds')
        os.makedirs(out_dir, exist_ok=True)
        with torch.no_grad():
            for batch in tgt_val_loader:
                batch_x = batch['image'].to(torch_device)
                batch_ndx = batch['index'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    ds_tgt.save_prediction_by_index(
                        out_dir, pred_y[sample_i].astype(np.uint32),
                        sample_ndx)
    else:
        out_dir = None

    if test_loader is not None:
        test_iou_eval = evaluation.EvaluatorIoU(ds_tgt.num_classes,
                                                bin_fill_holes)
        with torch.no_grad():
            for batch in test_loader:
                batch_x = batch['image'].to(torch_device)
                batch_ndx = batch['index'].numpy()
                # BUGFIX: labels must be read from *this* batch.  Previously
                # `batch_y` was left over from the last validation batch, so
                # test IoU was scored against wrong ground truth (and could
                # raise IndexError on a size mismatch).
                # NOTE(review): assumes test batches carry 'labels' like the
                # validation batches do — confirm against eval_data_pipeline.
                batch_y = batch['labels'].numpy()

                logits = eval_net(batch_x)
                pred_y = torch.argmax(logits, dim=1).detach().cpu().numpy()

                for sample_i, sample_ndx in enumerate(batch_ndx):
                    if save_preds:
                        ds_tgt.save_prediction_by_index(
                            out_dir, pred_y[sample_i].astype(np.uint32),
                            sample_ndx)
                    test_iou_eval.sample(batch_y[sample_i, 0],
                                         pred_y[sample_i],
                                         ignore_value=255)

        test_iou = test_iou_eval.score()
        test_miou = test_iou.mean()

        print('FINAL TEST: mIoU={:.3%}'.format(test_miou))
        print('-- TEST {}'.format(', '.join(
            ['{:.3%}'.format(x) for x in test_iou])))