def main(dataset='cifar10', data_path='/tmp/data', output_dir='/tmp/fixmatch', run_id=None, seed=1,
         block_depth=4, num_filters=32, num_labeled=40, sample_mode='label_dist_min1', num_epochs=1024,
         batches_per_epoch=1024, labeled_batch_size=64, unlabeled_batch_size=64 * 7, unlabeled_weight=1.,
         lr=0.03, momentum=0.9, nesterov=True, weight_decay=5e-4, bn_momentum=1e-3,
         exp_moving_avg_decay=1e-3, threshold=0.95, labeled_aug='weak', unlabeled_aug=('weak', 'strong'),
         dist_alignment=False, dist_alignment_batches=128, dist_alignment_eps=1e-6,
         checkpoint_interval=1024, max_checkpoints=25, num_workers=4, mixed_precision=True,
         devices=('cuda:0', )):
    """FixMatch training.

    Seeds the random number generators, pickles the run arguments to
    ``output_dir/run_id/args.pkl``, builds the datasets, model and trainer, then runs
    the training loop. Every ``checkpoint_interval`` batches (and on the last batch) the
    model is evaluated, a checkpoint is saved and the train/eval logs are pickled.

    Args:
        dataset: the dataset to use ('cifar10', 'cifar100', 'svhn')
        data_path: dataset root directory
        output_dir: directory to save logs and model checkpoints
        run_id: name for training run (output will be saved under output_dir/run_id);
            defaults to the current ISO-format timestamp
        seed: random seed
        block_depth: WideResNet block depth
        num_filters: WideResNet base filter count
        num_labeled: number of labeled examples
        sample_mode: labeled dataset sampling mode
            ('equal', 'label_dist', 'label_dist_min1', 'multinomial', 'multinomial_min1')
        num_epochs: number of training epochs
        batches_per_epoch: number of batches per epoch
        labeled_batch_size: number of labeled examples per batch
        unlabeled_batch_size: number of unlabeled examples per batch
            (total batch size will be labeled_batch_size + 2 * unlabeled_batch_size)
        unlabeled_weight: weight of unlabeled loss term
        lr: SGD initial learning rate
        momentum: SGD momentum parameter
        nesterov: whether to use SGD with Nesterov acceleration
        weight_decay: weight decay parameter
        bn_momentum: batch normalization momentum parameter
        exp_moving_avg_decay: model parameter exponential moving average decay
        threshold: confidence threshold
        labeled_aug: data augmentation mode for labeled examples
            ('none', 'weak', 'strong', 'weak_noflip', 'strong_noflip').
            'strong' augmentation uses RandAugment.
            'noflip' disables horizontal flip augmentation.
        unlabeled_aug: pair of augmentations for unlabeled examples
        dist_alignment: whether to apply distribution alignment heuristic
        dist_alignment_batches: number of batches used to compute moving average of label distribution
        dist_alignment_eps: smoothing parameter for estimating label distribution
        checkpoint_interval: number of batches between checkpoints
            (None checkpoints only after the final batch)
        max_checkpoints: maximum number of checkpoints to retain (None keeps all of them)
        num_workers: number of workers per data loader
        mixed_precision: whether to use mixed precision training
        devices: list of devices for data parallel training

    Raises:
        ValueError: if ``dataset`` is not one of the supported dataset names.
    """
    # initial setup
    num_batches = num_epochs * batches_per_epoch
    # Snapshot the arguments (plus num_batches) before any other local name is bound,
    # so that args.pkl only ever contains plain, picklable configuration values.
    args = dict(locals())

    # Local import keeps this block self-contained; Python's own RNG is seeded as well
    # so that any augmentation code relying on `random` is reproducible.
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    logger.info(pprint.pformat(args))
    run_id = datetime.datetime.now().isoformat() if run_id is None else run_id
    output_dir = os.path.join(output_dir, str(run_id))
    logger.info('output dir = %s', output_dir)
    # exist_ok avoids the check-then-create race of isdir() followed by makedirs()
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)
    train_logger, eval_logger = TableLogger(), TableLogger()

    # load datasets
    if dataset == 'cifar10':
        dataset_fn = get_cifar10
    elif dataset == 'cifar100':
        dataset_fn = get_cifar100
    elif dataset == 'svhn':
        dataset_fn = get_svhn
    else:
        raise ValueError('Invalid dataset {}'.format(dataset))
    datasets = dataset_fn(data_path, num_labeled, labeled_aug=labeled_aug, unlabeled_aug=unlabeled_aug,
                          sample_mode=sample_mode, whiten=True)

    model = modules.WideResNet(num_classes=datasets['labeled'].num_classes, bn_momentum=bn_momentum,
                               block_depth=block_depth, channels=num_filters)
    # The trainer receives constructors and builds the optimizer / scheduler / EMA itself.
    optimizer_ctor = partial(torch.optim.SGD, lr=lr, momentum=momentum, nesterov=nesterov,
                             weight_decay=weight_decay)
    scheduler_ctor = partial(utils.WarmupCosineLrScheduler, warmup_iter=0, max_iter=num_batches)
    evaluator = ModelEvaluator(datasets['test'], labeled_batch_size + unlabeled_batch_size, num_workers)
    param_avg_ctor = partial(modules.EMA, alpha=exp_moving_avg_decay)

    # Rolling checkpoints are named ckpt-<step>.pt; the dot is escaped so that only
    # files with a literal ".pt" suffix are ever considered for deletion.
    ckpt_pattern = re.compile(r'^ckpt-([0-9]+)\.pt$')

    def evaluate(model, avg_model, step):
        """Evaluate the raw and averaged models, log both, return the averaged accuracy."""
        results = evaluator.evaluate(model, device=devices[0])
        avg_results = evaluator.evaluate(avg_model, device=devices[0])
        valid_stats = {
            'valid_loss': avg_results.log_loss,
            'valid_accuracy': avg_results.accuracy,
            'valid_loss_noavg': results.log_loss,
            'valid_accuracy_noavg': results.accuracy
        }
        eval_logger.write(iter=step, **valid_stats)
        eval_logger.step()
        return avg_results.accuracy

    def prune_checkpoints():
        """Delete the oldest rolling checkpoints so that at most max_checkpoints remain."""
        if max_checkpoints is None:
            return
        found = [name for name in os.listdir(output_dir) if ckpt_pattern.match(name)]
        # order by the numeric step rather than lexicographically
        found.sort(key=lambda name: int(ckpt_pattern.match(name).group(1)))
        for stale in found[:max(len(found) - max_checkpoints, 0)]:
            os.remove(os.path.join(output_dir, stale))

    def checkpoint(model, avg_model, optimizer, scheduler, step, fmt='ckpt-{:08d}.pt'):
        """Save a checkpoint for `step`, prune old checkpoints and dump the logs."""
        path = os.path.join(output_dir, fmt.format(step))
        torch.save(
            dict(iter=step,
                 model=model.state_dict(),
                 avg_model=avg_model.state_dict(),
                 optimizer=optimizer.state_dict(),
                 scheduler=scheduler.state_dict()), path)
        prune_checkpoints()
        train_logger.to_dataframe().to_pickle(os.path.join(output_dir, 'train.log.pkl'))
        eval_logger.to_dataframe().to_pickle(os.path.join(output_dir, 'eval.log.pkl'))

    trainer = FixMatch(num_iters=num_batches,
                       num_workers=num_workers,
                       model_optimizer_ctor=optimizer_ctor,
                       lr_scheduler_ctor=scheduler_ctor,
                       param_avg_ctor=param_avg_ctor,
                       labeled_batch_size=labeled_batch_size,
                       unlabeled_batch_size=unlabeled_batch_size,
                       unlabeled_weight=unlabeled_weight,
                       threshold=threshold,
                       dist_alignment=dist_alignment,
                       dist_alignment_batches=dist_alignment_batches,
                       dist_alignment_eps=dist_alignment_eps,
                       mixed_precision=mixed_precision,
                       devices=devices)

    timer = utils.Timer()
    with tqdm(desc='train', total=num_batches, position=0) as train_pbar:
        train_iter = utils.Generator(
            trainer.train_iter(model, datasets['labeled'], datasets['unlabeled']))
        eval_acc = None

        # training loop
        for i, stats in enumerate(train_iter):
            train_pbar.set_postfix(loss=stats.loss, eval_acc=eval_acc, refresh=False)
            train_pbar.update()
            train_logger.write(loss=stats.loss,
                               loss_labeled=stats.loss_labeled,
                               loss_unlabeled=stats.loss_unlabeled,
                               threshold_frac=stats.threshold_frac,
                               time=timer())

            step = i + 1
            at_interval = (checkpoint_interval is not None and i > 0
                           and step % checkpoint_interval == 0)
            if at_interval or i == num_batches - 1:
                # Evaluate first: checkpoint() also dumps the eval log, so evaluating
                # afterwards would leave the latest (and the final) evaluation unsaved.
                eval_acc = evaluate(stats.model, stats.avg_model, step)
                checkpoint(stats.model, stats.avg_model, stats.optimizer, stats.scheduler, step)
                logger.info('eval acc = %.4f | allocated frac = %.4f', eval_acc, stats.threshold_frac)
            train_logger.step()
def main(
        num_workers=8,
        num_filters=32,
        dataset='cifar10',
        data_path='/tmp/data',
        output_dir='/tmp/sla',
        run_id=None,
        num_labeled=40,
        seed=1,
        num_epochs=1024,
        batches_per_epoch=1024,
        checkpoint_interval=1024,
        snapshot_interval=None,
        max_checkpoints=25,
        optimizer='sgd',
        lr=0.03,
        momentum=0.9,
        nesterov=True,
        weight_decay=5e-4,
        bn_momentum=1e-3,
        labeled_batch_size=64,
        unlabeled_batch_size=64*7,
        unlabeled_weight=1.,
        exp_moving_avg_decay=1e-3,
        allocation_schedule=((0., 1.), (0., 1.)),
        entropy_reg=100.,
        update_tol=0.01,
        labeled_aug='weak',
        unlabeled_aug=('weak', 'strong'),
        whiten=True,
        sample_mode='label_dist_min1',
        upper_bound_method='empirical',
        upper_bound_kwargs=None,
        mixed_precision=True,
        devices=('cuda:0',)):
    """Self-training with SLASelfTraining.

    Seeds the random number generators, pickles the run arguments to
    ``output_dir/run_id/args.pkl``, builds the datasets, a WideResNet and the trainer,
    then runs the training loop. Every ``checkpoint_interval`` batches (and on the last
    batch) the model is evaluated, a rolling checkpoint is saved and the train/eval logs
    are pickled; every ``snapshot_interval`` batches a snapshot is saved that is never
    pruned.

    Args:
        num_workers: number of workers, passed to the evaluator and the trainer
        num_filters: WideResNet ``channels`` argument
        dataset: the dataset to use ('cifar10', 'cifar100', 'svhn')
        data_path: dataset root directory
        output_dir: directory to save logs and model checkpoints
        run_id: name for training run (output is saved under output_dir/run_id);
            defaults to the current ISO-format timestamp
        num_labeled: number of labeled examples
        seed: random seed
        num_epochs: number of training epochs
        batches_per_epoch: number of batches per epoch
        checkpoint_interval: number of batches between rolling checkpoints
            (None checkpoints only after the final batch)
        snapshot_interval: number of batches between preserved snapshots (None disables)
        max_checkpoints: maximum number of rolling checkpoints to retain (None keeps all)
        optimizer: recorded in args.pkl but otherwise unused; SGD is always constructed
        lr: SGD initial learning rate
        momentum: SGD momentum parameter
        nesterov: whether to use SGD with Nesterov acceleration
        weight_decay: weight decay parameter
        bn_momentum: WideResNet ``bn_momentum`` argument
        labeled_batch_size: number of labeled examples per batch
        unlabeled_batch_size: number of unlabeled examples per batch
        unlabeled_weight: forwarded to the trainer (presumably the weight of the
            unlabeled loss term -- confirm against SLASelfTraining)
        exp_moving_avg_decay: ``alpha`` of the ``modules.EMA`` parameter average
        allocation_schedule: arguments unpacked into ``utils.PiecewiseLinear`` to build
            the trainer's allocation schedule
        entropy_reg: forwarded to the trainer unchanged
        update_tol: forwarded to the trainer unchanged
        labeled_aug: data augmentation mode for labeled examples
        unlabeled_aug: pair of augmentations for unlabeled examples
        whiten: forwarded to the dataset constructor
        sample_mode: labeled dataset sampling mode
        upper_bound_method: forwarded to the trainer unchanged
        upper_bound_kwargs: dict forwarded to the trainer; None means an empty dict
        mixed_precision: whether to use mixed precision training
        devices: list of devices; the first one is used for evaluation

    Raises:
        ValueError: if ``dataset`` is not one of the supported dataset names.
    """
    # A fresh dict per call: a `{}` default would be shared (and mutable) across calls.
    if upper_bound_kwargs is None:
        upper_bound_kwargs = {}

    # initial setup
    num_batches = num_epochs * batches_per_epoch
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Snapshot of the arguments (plus num_batches); taken before other locals exist.
    args = dict(locals())
    logger.info(pprint.pformat(args))
    run_id = datetime.datetime.now().isoformat() if run_id is None else run_id
    output_dir = os.path.join(output_dir, str(run_id))
    logger.info('output dir = %s', output_dir)
    # exist_ok avoids the check-then-create race of isdir() followed by makedirs()
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)
    train_logger, eval_logger = TableLogger(), TableLogger()

    # load datasets
    if dataset == 'cifar10':
        dataset_fn = get_cifar10
    elif dataset == 'cifar100':
        dataset_fn = get_cifar100
    elif dataset == 'svhn':
        dataset_fn = get_svhn
    else:
        raise ValueError('Invalid dataset {}'.format(dataset))
    datasets = dataset_fn(
        data_path, num_labeled, labeled_aug=labeled_aug, unlabeled_aug=unlabeled_aug,
        sample_mode=sample_mode, whiten=whiten)

    model = modules.WideResNet(
        num_classes=datasets['labeled'].num_classes, bn_momentum=bn_momentum, channels=num_filters)
    # Named *_ctor so the (unused) `optimizer` argument is not silently shadowed.
    optimizer_ctor = partial(torch.optim.SGD, lr=lr, momentum=momentum, nesterov=nesterov,
                             weight_decay=weight_decay)
    scheduler_ctor = partial(utils.WarmupCosineLrScheduler, warmup_iter=0, max_iter=num_batches)
    evaluator = ModelEvaluator(datasets['test'], labeled_batch_size + unlabeled_batch_size, num_workers)

    # Rolling checkpoints are named ckpt-<step>.pt; the dot is escaped so that only
    # files with a literal ".pt" suffix are ever considered for deletion.
    ckpt_pattern = re.compile(r'^ckpt-([0-9]+)\.pt$')

    def evaluate(model, avg_model, step):
        """Evaluate the raw and averaged models, log both, return the averaged accuracy."""
        results = evaluator.evaluate(model, device=devices[0])
        avg_results = evaluator.evaluate(avg_model, device=devices[0])
        valid_stats = {
            'valid_loss': avg_results.log_loss,
            'valid_accuracy': avg_results.accuracy,
            'valid_loss_noavg': results.log_loss,
            'valid_accuracy_noavg': results.accuracy
        }
        eval_logger.write(iter=step, **valid_stats)
        eval_logger.step()
        return avg_results.accuracy

    def prune_checkpoints():
        """Delete the oldest rolling checkpoints so that at most max_checkpoints remain."""
        if max_checkpoints is None:
            return
        found = [name for name in os.listdir(output_dir) if ckpt_pattern.match(name)]
        # order by the numeric step rather than lexicographically
        found.sort(key=lambda name: int(ckpt_pattern.match(name).group(1)))
        for stale in found[:max(len(found) - max_checkpoints, 0)]:
            os.remove(os.path.join(output_dir, stale))

    def checkpoint(model, avg_model, optimizer, scheduler, step, fmt='ckpt-{:08d}.pt'):
        """Save a checkpoint (or snapshot) for `step`, prune old checkpoints, dump the logs."""
        path = os.path.join(output_dir, fmt.format(step))
        torch.save(dict(
            iter=step,
            model=model.state_dict(),
            avg_model=avg_model.state_dict(),
            optimizer=optimizer.state_dict(),
            scheduler=scheduler.state_dict()), path)
        # Only ckpt-*.pt files are pruned, so snapshot-*.pt files are always preserved.
        prune_checkpoints()
        train_logger.to_dataframe().to_pickle(os.path.join(output_dir, 'train.log.pkl'))
        eval_logger.to_dataframe().to_pickle(os.path.join(output_dir, 'eval.log.pkl'))

    trainer = SLASelfTraining(
        num_epochs=num_epochs,
        batches_per_epoch=batches_per_epoch,
        num_workers=num_workers,
        model_optimizer_ctor=optimizer_ctor,
        lr_scheduler_ctor=scheduler_ctor,
        param_avg_ctor=partial(modules.EMA, alpha=exp_moving_avg_decay),
        labeled_batch_size=labeled_batch_size,
        unlabeled_batch_size=unlabeled_batch_size,
        unlabeled_weight=unlabeled_weight,
        allocation_schedule=utils.PiecewiseLinear(*allocation_schedule),
        entropy_reg=entropy_reg,
        update_tol=update_tol,
        upper_bound_method=upper_bound_method,
        upper_bound_kwargs=upper_bound_kwargs,
        mixed_precision=mixed_precision,
        devices=devices)

    timer = utils.Timer()
    with tqdm(desc='train', total=num_batches, position=0) as train_pbar:
        train_iter = utils.Generator(
            trainer.train_iter(model, datasets['labeled'].num_classes, datasets['labeled'],
                               datasets['unlabeled']))
        # presumably generator-based running averages; send(None) starts them so that
        # later send(value) calls return the updated statistic -- confirm in utils.ema
        smoothed_loss = utils.ema(0.3, avg_only=True)
        smoothed_loss.send(None)
        smoothed_acc = utils.ema(1., avg_only=False)
        smoothed_acc.send(None)
        eval_stats = None, None

        # training loop
        for i, stats in enumerate(train_iter):
            # NOTE(review): items that are not Stats records are skipped entirely
            # (no logging, checkpointing or logger step) -- confirm this matches what
            # trainer.train_iter yields.
            if not isinstance(stats, trainer.__class__.Stats):
                continue

            train_pbar.set_postfix(
                loss=smoothed_loss.send(stats.loss), eval_acc=eval_stats[0], eval_v=eval_stats[1],
                refresh=False)
            train_pbar.update()
            # NOTE(review): assigned_frac is logged as a tensor, not a Python float.
            train_logger.write(
                loss=stats.loss, loss_labeled=stats.loss_labeled, loss_unlabeled=stats.loss_unlabeled,
                mean_imputed_labels=stats.label_vars.data.mean(0).cpu().numpy(),
                scaling_vars=stats.scaling_vars.data.cpu().numpy(),
                allocation_param=stats.allocation_param,
                assigned_frac=stats.label_vars.data.sum(-1).mean(),
                assignment_err=stats.assgn_err, assignment_iters=stats.assgn_iters, time=timer())

            step = i + 1
            at_interval = (checkpoint_interval is not None and i > 0
                           and step % checkpoint_interval == 0)
            if at_interval or i == num_batches - 1:
                # Evaluate before checkpointing so the dumped eval log includes this result.
                eval_acc = evaluate(stats.model, stats.avg_model, step)
                eval_stats = smoothed_acc.send(eval_acc)
                checkpoint(stats.model, stats.avg_model, stats.optimizer, stats.scheduler, step)
                logger.info('eval acc = %.4f | allocated frac = %.4f | allocation param = %.4f',
                            eval_acc, stats.label_vars.mean(0).sum().cpu().item(),
                            stats.allocation_param)
                logger.info('assignment err = %.4e | assignment iters = %d',
                            stats.assgn_err, stats.assgn_iters)
                logger.info('batch assignments = {}'.format(stats.label_vars.mean(0).cpu().numpy()))
                logger.info('scaling vars = {}'.format(stats.scaling_vars.cpu().numpy()))

            # take snapshots that are guaranteed to be preserved
            if snapshot_interval is not None and i > 0 and step % snapshot_interval == 0:
                checkpoint(stats.model, stats.avg_model, stats.optimizer, stats.scheduler, step,
                           'snapshot-{:08d}.pt')
            train_logger.step()