def init_pipe(training):
    """Build the transform and data-loader/sampler pair for one data split.

    Relies on enclosing-scope config (``dataset_name``, ``subset_path``,
    ``unlabeled_frac``, ``split_seed``, ``normalize``, ``root_path``,
    ``image_folder``).

    :param training: True for the training split (keeps the configured
        unlabeled fraction); False for evaluation (unlabeled fraction forced
        to 0, i.e. the full labeled set is used).
    :return: tuple (transform, init_transform, data_loader, data_sampler)
    """
    # Evaluation passes always see every labeled image.
    frac = unlabeled_frac if training else 0.

    # -- basic augmentations with a forced center crop: this pipe is meant
    #    for evaluation-style passes, not heavy training augmentation
    transform, init_transform = make_transforms(
        dataset_name=dataset_name,
        subset_path=subset_path,
        unlabeled_frac=frac,
        training=training,
        split_seed=split_seed,
        basic_augmentations=True,
        force_center_crop=True,
        normalize=normalize)

    # -- single-process loader (world_size=1, rank=0) with a small fixed
    #    supervised batch size; no stratification, keep incomplete batches
    data_loader, data_sampler = init_data(
        dataset_name=dataset_name,
        transform=transform,
        init_transform=init_transform,
        u_batch_size=None,
        s_batch_size=16,
        stratify=False,
        classes_per_batch=None,
        world_size=1,
        rank=0,
        root_path=root_path,
        image_folder=image_folder,
        training=training,
        copy_data=False,
        drop_last=False)

    return transform, init_transform, data_loader, data_sampler
def main(args):
    """Run PAWS self-supervised pre-training as described by the config dict.

    :param args: nested config dict with 'meta', 'criterion', 'data',
        'optimization' and 'logging' sections (see the key reads below for
        the expected schema).

    Side effects: initializes torch.distributed, writes a per-rank CSV log,
    and saves 'latest'/'best'/periodic checkpoints under the logging folder.
    """
    # ----------------------------------------------------------------------- #
    #  PASSED IN PARAMS FROM CONFIG FILE
    # ----------------------------------------------------------------------- #

    # -- META
    model_name = args['meta']['model_name']
    output_dim = args['meta']['output_dim']
    load_model = args['meta']['load_checkpoint']
    r_file = args['meta']['read_checkpoint']
    copy_data = args['meta']['copy_data']
    use_fp16 = args['meta']['use_fp16']
    use_pred_head = args['meta']['use_pred_head']
    device = torch.device(args['meta']['device'])
    torch.cuda.set_device(device)

    # -- CRITERION
    reg = args['criterion']['me_max']
    supervised_views = args['criterion']['supervised_views']
    classes_per_batch = args['criterion']['classes_per_batch']
    s_batch_size = args['criterion']['supervised_imgs_per_class']
    u_batch_size = args['criterion']['unsupervised_batch_size']
    temperature = args['criterion']['temperature']
    sharpen = args['criterion']['sharpen']

    # -- DATA
    unlabeled_frac = args['data']['unlabeled_frac']
    color_jitter = args['data']['color_jitter_strength']
    normalize = args['data']['normalize']
    root_path = args['data']['root_path']
    image_folder = args['data']['image_folder']
    dataset_name = args['data']['dataset']
    subset_path = args['data']['subset_path']
    unique_classes = args['data']['unique_classes_per_rank']
    multicrop = args['data']['multicrop']
    label_smoothing = args['data']['label_smoothing']
    data_seed = None
    # Crop parameters differ per dataset family; multi-crop uses smaller,
    # lower-scale crops than the two main views.
    if 'cifar10' in dataset_name:
        data_seed = args['data']['data_seed']
        crop_scale = (0.75, 1.0) if multicrop > 0 else (0.5, 1.0)
        mc_scale = (0.3, 0.75)
        mc_size = 18
    else:
        crop_scale = (0.14, 1.0) if multicrop > 0 else (0.08, 1.0)
        mc_scale = (0.05, 0.14)
        mc_size = 96

    # -- OPTIMIZATION
    wd = float(args['optimization']['weight_decay'])
    num_epochs = args['optimization']['epochs']
    warmup = args['optimization']['warmup']
    start_lr = args['optimization']['start_lr']
    lr = args['optimization']['lr']
    final_lr = args['optimization']['final_lr']
    mom = args['optimization']['momentum']
    nesterov = args['optimization']['nesterov']

    # -- LOGGING
    folder = args['logging']['folder']
    tag = args['logging']['write_tag']
    # ----------------------------------------------------------------------- #

    # -- init torch distributed backend
    world_size, rank = init_distributed()
    logger.info(f'Initialized (rank/world-size) {rank}/{world_size}')

    # -- log/checkpointing paths
    # NOTE: save_path deliberately mixes an f-string with plain concatenation
    # so that the literal '{epoch}' placeholder survives for .format() below.
    log_file = os.path.join(folder, f'{tag}_r{rank}.csv')
    save_path = os.path.join(folder, f'{tag}' + '-ep{epoch}.pth.tar')
    latest_path = os.path.join(folder, f'{tag}-latest.pth.tar')
    best_path = os.path.join(folder, f'{tag}' + '-best.pth.tar')
    load_path = None
    if load_model:
        # Resume from an explicit file if given, otherwise from 'latest'.
        load_path = os.path.join(folder, r_file) if r_file is not None else latest_path

    # -- make csv_logger (per-rank training log)
    csv_logger = CSVLogger(log_file,
                           ('%d', 'epoch'),
                           ('%d', 'itr'),
                           ('%.5f', 'paws-xent-loss'),
                           ('%.5f', 'paws-me_max-reg'),
                           ('%d', 'time (ms)'))

    # -- init model
    encoder = init_model(
        device=device,
        model_name=model_name,
        use_pred=use_pred_head,
        output_dim=output_dim)
    if world_size > 1:
        # Multi-GPU: convert BatchNorm layers to apex synchronized BatchNorm.
        process_group = apex.parallel.create_syncbn_process_group(0)
        encoder = apex.parallel.convert_syncbn_model(
            encoder, process_group=process_group)

    # -- init losses
    paws = init_paws_loss(
        multicrop=multicrop,
        tau=temperature,
        T=sharpen,
        me_max=reg)
    # -- assume support images are sampled with ClassStratifiedSampler
    labels_matrix = make_labels_matrix(
        num_classes=classes_per_batch,
        s_batch_size=s_batch_size,
        world_size=world_size,
        device=device,
        unique_classes=unique_classes,
        smoothing=label_smoothing)

    # -- make data transforms (full training augmentations here, unlike the
    #    basic/center-crop pipes used for evaluation)
    transform, init_transform = make_transforms(
        dataset_name=dataset_name,
        subset_path=subset_path,
        unlabeled_frac=unlabeled_frac,
        training=True,
        split_seed=data_seed,
        crop_scale=crop_scale,
        basic_augmentations=False,
        color_jitter=color_jitter,
        normalize=normalize)

    multicrop_transform = (multicrop, None)
    if multicrop > 0:
        multicrop_transform = make_multicrop_transform(
            dataset_name=dataset_name,
            num_crops=multicrop,
            size=mc_size,
            crop_scale=mc_scale,
            normalize=normalize,
            color_distortion=color_jitter)

    # -- init data-loaders/samplers (paired unsupervised + supervised loaders)
    (unsupervised_loader, unsupervised_sampler,
     supervised_loader, supervised_sampler) = init_data(
         dataset_name=dataset_name,
         transform=transform,
         init_transform=init_transform,
         supervised_views=supervised_views,
         u_batch_size=u_batch_size,
         s_batch_size=s_batch_size,
         unique_classes=unique_classes,
         classes_per_batch=classes_per_batch,
         multicrop_transform=multicrop_transform,
         world_size=world_size,
         rank=rank,
         root_path=root_path,
         image_folder=image_folder,
         training=True,
         copy_data=copy_data)
    iter_supervised = None
    ipe = len(unsupervised_loader)
    logger.info(f'iterations per epoch: {ipe}')

    # -- init optimizer and scheduler
    scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
    encoder, optimizer, scheduler = init_opt(
        encoder=encoder,
        weight_decay=wd,
        start_lr=start_lr,
        ref_lr=lr,
        final_lr=final_lr,
        ref_mom=mom,
        nesterov=nesterov,
        iterations_per_epoch=ipe,
        warmup=warmup,
        num_epochs=num_epochs)
    if world_size > 1:
        encoder = DistributedDataParallel(encoder, broadcast_buffers=False)

    start_epoch = 0
    # -- load training checkpoint
    if load_model:
        encoder, optimizer, start_epoch = load_checkpoint(
            r_path=load_path,
            encoder=encoder,
            opt=optimizer,
            scaler=scaler,
            use_fp16=use_fp16)
        # Replay the scheduler to the resumed position (one step per
        # iteration already trained).
        for _ in range(start_epoch):
            for _ in range(ipe):
                scheduler.step()

    # -- TRAINING LOOP
    best_loss = None
    for epoch in range(start_epoch, num_epochs):
        logger.info('Epoch %d' % (epoch + 1))

        # -- update distributed-data-loader epoch (reshuffles shards)
        unsupervised_sampler.set_epoch(epoch)
        if supervised_sampler is not None:
            supervised_sampler.set_epoch(epoch)

        loss_meter = AverageMeter()
        ploss_meter = AverageMeter()
        rloss_meter = AverageMeter()
        time_meter = AverageMeter()
        data_meter = AverageMeter()

        for itr, udata in enumerate(unsupervised_loader):

            if use_fp16:
                udata = [u.half() for u in udata]

            def load_imgs():
                # Assemble one training batch: two unsupervised views,
                # optional multicrop views, and labeled support images.
                # -- unsupervised imgs (2 views)
                imgs = [u.to(device) for u in udata[:2]]
                # -- unsupervised multicrop img views
                mc_imgs = None
                if multicrop > 0:
                    mc_imgs = torch.cat([u.to(device) for u in udata[2:-1]], dim=0)
                # -- labeled support imgs
                # NOTE(review): 'global' (not nonlocal) is used here, so the
                # first next() raises NameError (the module global does not
                # exist yet), is caught by 'except Exception', and the
                # iterator is (re)built from supervised_loader. The
                # 'iter_supervised = None' local in main() is never the
                # variable read here. The 'finally' block assumes one of the
                # two next() calls produced sdata.
                global iter_supervised
                try:
                    sdata = next(iter_supervised)
                except Exception:
                    iter_supervised = iter(supervised_loader)
                    logger.info(f'len.supervised_loader: {len(iter_supervised)}')
                    sdata = next(iter_supervised)
                finally:
                    if use_fp16:
                        sdata = [s.half() for s in sdata]
                    simgs = [s.to(device) for s in sdata[:-1]]
                    # Repeat the fixed label matrix once per supervised view.
                    labels = torch.cat([labels_matrix for _ in range(supervised_views)])
                # -- concatenate unlabeled images and labeled support images
                imgs = torch.cat(imgs + simgs, dim=0)
                return imgs, mc_imgs, labels
            (imgs, mc_imgs, labels), dtime = gpu_timer(load_imgs)
            data_meter.update(dtime)

            def train_step():
                # One optimization step: forward in (optional) mixed
                # precision, loss in fp32, scaled backward, step, update.
                with torch.cuda.amp.autocast(enabled=use_fp16):
                    optimizer.zero_grad()
                    (h, z), z_mc = encoder(imgs, mc_imgs, return_before_head=True)

                # Compute paws loss in full precision
                with torch.cuda.amp.autocast(enabled=False):

                    # Step 1. convert representations to fp32
                    h, z = h.float(), z.float()
                    if z_mc is not None:
                        z_mc = z_mc.float()

                    # Step 2. determine anchor views/supports and their
                    #         corresponding target views/supports
                    if not use_pred_head:
                        h = z
                    # Layout assumed along dim 0: first 2*u_batch_size rows
                    # are the two unsupervised views, the rest are supports.
                    target_supports = h[2 * u_batch_size:].detach()
                    target_views = h[:2 * u_batch_size].detach()
                    # Swap the two view halves so each anchor view is paired
                    # with the *other* view as its target.
                    target_views = torch.cat([
                        target_views[u_batch_size:],
                        target_views[:u_batch_size]
                    ], dim=0)
                    # --
                    anchor_supports = z[2 * u_batch_size:]
                    anchor_views = z[:2 * u_batch_size]
                    if multicrop > 0:
                        anchor_views = torch.cat([anchor_views, z_mc], dim=0)

                    # Step 3. compute paws loss with me-max regularization
                    (ploss, me_max) = paws(
                        anchor_views=anchor_views,
                        anchor_supports=anchor_supports,
                        anchor_support_labels=labels,
                        target_views=target_views,
                        target_supports=target_supports,
                        target_support_labels=labels)
                    loss = ploss + me_max

                scaler.scale(loss).backward()
                # NOTE(review): scaler.step() forwards optimizer.step()'s
                # return value; the .avg/.min/.max accesses below presume a
                # custom optimizer that returns an lr-stats object — confirm
                # against init_opt.
                lr_stats = scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                return (float(loss), float(ploss), float(me_max), lr_stats)
            (loss, ploss, rloss, lr_stats), etime = gpu_timer(train_step)
            loss_meter.update(loss)
            ploss_meter.update(ploss)
            rloss_meter.update(rloss)
            time_meter.update(etime)

            # Log periodically, and always when the loss went nan/inf so the
            # CSV captures the failing iteration before the assert fires.
            if (itr % log_freq == 0) or np.isnan(loss) or np.isinf(loss):
                csv_logger.log(epoch + 1, itr,
                               ploss_meter.avg,
                               rloss_meter.avg,
                               time_meter.avg)
                logger.info('[%d, %5d] loss: %.3f (%.3f %.3f) '
                            '(%d ms; %d ms)'
                            % (epoch + 1, itr,
                               loss_meter.avg,
                               ploss_meter.avg,
                               rloss_meter.avg,
                               time_meter.avg,
                               data_meter.avg))
                if lr_stats is not None:
                    logger.info('[%d, %5d] lr_stats: %.3f (%.2e, %.2e)'
                                % (epoch + 1, itr,
                                   lr_stats.avg,
                                   lr_stats.min,
                                   lr_stats.max))
            assert not np.isnan(loss), 'loss is nan'

        # -- logging/checkpointing
        logger.info('avg. loss %.3f' % loss_meter.avg)

        if rank == 0:
            save_dict = {
                'encoder': encoder.state_dict(),
                'opt': optimizer.state_dict(),
                'epoch': epoch + 1,
                'unlabel_prob': unlabeled_frac,
                'loss': loss_meter.avg,
                's_batch_size': s_batch_size,
                'u_batch_size': u_batch_size,
                'world_size': world_size,
                'lr': lr,
                'temperature': temperature,
                'amp': scaler.state_dict()
            }
            torch.save(save_dict, latest_path)
            if best_loss is None or best_loss > loss_meter.avg:
                best_loss = loss_meter.avg
                logger.info('updating "best" checkpoint')
                torch.save(save_dict, best_path)
            # Periodic snapshot: every checkpoint_freq epochs, plus every
            # 10 epochs during the first checkpoint_freq epochs.
            if (epoch + 1) % checkpoint_freq == 0 \
                    or (epoch + 1) % 10 == 0 and epoch < checkpoint_freq:
                torch.save(save_dict, save_path.format(epoch=f'{epoch + 1}'))
def main(args):
    """Fine-tune (or evaluate) a pre-trained encoder with cross-entropy.

    :param args: nested config dict with 'meta', 'data', 'optimization' and
        'logging' sections. args['meta']['training'] selects fine-tuning
        (True) versus evaluation-only (False).
    :return: (train_top1, val_top1) accuracies from the last epoch.

    Side effects: initializes torch.distributed and, when training, saves a
    best-accuracy checkpoint on rank 0.
    """
    # -- META
    model_name = args['meta']['model_name']
    port = args['meta']['master_port']
    load_checkpoint = args['meta']['load_checkpoint']
    training = args['meta']['training']
    copy_data = args['meta']['copy_data']
    use_fp16 = args['meta']['use_fp16']
    device = torch.device(args['meta']['device'])
    torch.cuda.set_device(device)

    # -- DATA
    unlabeled_frac = args['data']['unlabeled_frac']
    normalize = args['data']['normalize']
    root_path = args['data']['root_path']
    image_folder = args['data']['image_folder']
    dataset_name = args['data']['dataset']
    subset_path = args['data']['subset_path']
    num_classes = args['data']['num_classes']
    data_seed = None
    if 'cifar10' in dataset_name:
        data_seed = args['data']['data_seed']
    crop_scale = (0.5, 1.0) if 'cifar10' in dataset_name else (0.08, 1.0)

    # -- OPTIMIZATION
    wd = float(args['optimization']['weight_decay'])
    ref_lr = args['optimization']['lr']
    use_lars = args['optimization']['use_lars']
    zero_init = args['optimization']['zero_init']
    num_epochs = args['optimization']['epochs']

    # -- LOGGING
    folder = args['logging']['folder']
    tag = args['logging']['write_tag']
    r_file_enc = args['logging']['pretrain_path']

    # -- log/checkpointing paths (read pre-trained encoder; write fine-tuned)
    r_enc_path = os.path.join(folder, r_file_enc)
    w_enc_path = os.path.join(folder, f'{tag}-fine-tune.pth.tar')

    # -- init distributed
    world_size, rank = init_distributed(port)
    logger.info(f'initialized rank/world-size: {rank}/{world_size}')

    # -- optimization/evaluation params
    if training:
        batch_size = 256
    else:
        # Evaluation-only: small batches, full labeled set, force loading the
        # fine-tuned checkpoint, and a single pass over the data.
        batch_size = 16
        unlabeled_frac = 0.0
        load_checkpoint = True
        num_epochs = 1

    # -- init loss
    criterion = torch.nn.CrossEntropyLoss()

    # -- make train data transforms and data loaders/samples
    transform, init_transform = make_transforms(
        dataset_name=dataset_name,
        subset_path=subset_path,
        unlabeled_frac=unlabeled_frac,
        training=training,
        crop_scale=crop_scale,
        split_seed=data_seed,
        basic_augmentations=True,
        normalize=normalize)
    (data_loader, dist_sampler) = init_data(
        dataset_name=dataset_name,
        transform=transform,
        init_transform=init_transform,
        u_batch_size=None,
        s_batch_size=batch_size,
        classes_per_batch=None,
        world_size=world_size,
        rank=rank,
        root_path=root_path,
        image_folder=image_folder,
        training=training,
        copy_data=copy_data)
    ipe = len(data_loader)
    logger.info(f'initialized data-loader (ipe {ipe})')

    # -- make val data transforms and data loaders/samples
    # unlabeled_frac=-1 with training=True: presumably selects the full
    # validation split — confirm against make_transforms.
    val_transform, val_init_transform = make_transforms(
        dataset_name=dataset_name,
        subset_path=subset_path,
        unlabeled_frac=-1,
        training=True,
        basic_augmentations=True,
        force_center_crop=True,
        normalize=normalize)
    # Validation loader is not sharded (world_size=1, rank=0): every rank
    # sees the full validation set.
    (val_data_loader, val_dist_sampler) = init_data(
        dataset_name=dataset_name,
        transform=val_transform,
        init_transform=val_init_transform,
        u_batch_size=None,
        s_batch_size=batch_size,
        classes_per_batch=None,
        world_size=1,
        rank=0,
        root_path=root_path,
        image_folder=image_folder,
        training=True,
        copy_data=copy_data)
    logger.info(f'initialized val data-loader (ipe {len(val_data_loader)})')

    # -- init model and optimizer
    scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
    encoder, optimizer, scheduler = init_model(
        device=device,
        device_str=args['meta']['device'],
        num_classes=num_classes,
        training=training,
        use_fp16=use_fp16,
        r_enc_path=r_enc_path,
        iterations_per_epoch=ipe,
        world_size=world_size,
        ref_lr=ref_lr,
        weight_decay=wd,
        use_lars=use_lars,
        zero_init=zero_init,
        num_epochs=num_epochs,
        model_name=model_name)

    best_acc = None
    start_epoch = 0
    # -- load checkpoint (always when evaluating; optionally when resuming)
    if not training or load_checkpoint:
        encoder, optimizer, scheduler, start_epoch, best_acc = load_from_path(
            r_path=w_enc_path,
            encoder=encoder,
            opt=optimizer,
            sched=scheduler,
            scaler=scaler,
            device_str=args['meta']['device'],
            use_fp16=use_fp16)
    if not training:
        logger.info('putting model in eval mode')
        encoder.eval()
        # Log the trainable (non-fc) parameter count for reference.
        logger.info(
            sum(p.numel() for n, p in encoder.named_parameters()
                if p.requires_grad and ('fc' not in n)))
        # Evaluation ignores the checkpoint's epoch counter.
        start_epoch = 0

    for epoch in range(start_epoch, num_epochs):

        def train_step():
            # One full pass over data_loader; trains only when 'training' is
            # True, otherwise just accumulates accuracy. Returns top-1 (%).
            # -- update distributed-data-loader epoch
            dist_sampler.set_epoch(epoch)
            top1_correct, top5_correct, total = 0, 0, 0
            for i, data in enumerate(data_loader):
                with torch.cuda.amp.autocast(enabled=use_fp16):
                    inputs, labels = data[0].to(device), data[1].to(device)
                    outputs = encoder(inputs)
                    loss = criterion(outputs, labels)
                total += inputs.shape[0]
                top5_correct += float(
                    outputs.topk(5, dim=1).indices.eq(labels.unsqueeze(1)).sum())
                top1_correct += float(
                    outputs.max(dim=1).indices.eq(labels).sum())
                top1_acc = 100. * top1_correct / total
                top5_acc = 100. * top5_correct / total
                if training:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()
                    optimizer.zero_grad()
                if i % log_freq == 0:
                    logger.info('[%d, %5d] %.3f%% %.3f%% (loss: %.3f)'
                                % (epoch + 1, i, top1_acc, top5_acc, loss))
            return 100. * top1_correct / total

        def val_step():
            # Evaluate a deep copy of the encoder (so train-mode state such
            # as BatchNorm statistics on the live model is untouched).
            # Returns top-1 accuracy (%) over the validation loader.
            val_encoder = copy.deepcopy(encoder).eval()
            top1_correct, total = 0, 0
            for i, data in enumerate(val_data_loader):
                inputs, labels = data[0].to(device), data[1].to(device)
                outputs = val_encoder(inputs)
                total += inputs.shape[0]
                top1_correct += float(
                    outputs.max(dim=1).indices.eq(labels).sum())
                top1_acc = 100. * top1_correct / total
            logger.info('[%d, %5d] %.3f%%' % (epoch + 1, i, top1_acc))
            return 100. * top1_correct / total

        train_top1 = 0.
        train_top1 = train_step()
        with torch.no_grad():
            val_top1 = val_step()

        log_str = 'train:' if training else 'test:'
        logger.info('[%d] (%s: %.3f%%) (val: %.3f%%)'
                    % (epoch + 1, log_str, train_top1, val_top1))

        # -- logging/checkpointing (rank 0 saves only on a new best val acc)
        if training and (rank == 0) and ((best_acc is None) or (best_acc < val_top1)):
            best_acc = val_top1
            save_dict = {
                'encoder': encoder.state_dict(),
                'opt': optimizer.state_dict(),
                'sched': scheduler.state_dict(),
                'epoch': epoch + 1,
                'unlabel_prob': unlabeled_frac,
                'world_size': world_size,
                'best_top1_acc': best_acc,
                'batch_size': batch_size,
                'lr': ref_lr,
                'amp': scaler.state_dict()
            }
            torch.save(save_dict, w_enc_path)

    return train_top1, val_top1
def main(args):
    """Fine-tune a pre-trained encoder with the SUNCET (SNN) loss.

    :param args: nested config dict with 'meta', 'data', 'criterion',
        'optimization' and 'logging' sections.

    Side effects: initializes torch.distributed, runs validation via
    val_run() each epoch, and saves a best-accuracy checkpoint on rank 0.
    """
    # -- META
    model_name = args['meta']['model_name']
    load_checkpoint = args['meta']['load_checkpoint']
    copy_data = args['meta']['copy_data']
    output_dim = args['meta']['output_dim']
    use_pred_head = args['meta']['use_pred_head']
    use_fp16 = args['meta']['use_fp16']
    device = torch.device(args['meta']['device'])
    torch.cuda.set_device(device)

    # -- DATA
    unlabeled_frac = args['data']['unlabeled_frac']
    label_smoothing = args['data']['label_smoothing']
    normalize = args['data']['normalize']
    root_path = args['data']['root_path']
    image_folder = args['data']['image_folder']
    dataset_name = args['data']['dataset']
    subset_path = args['data']['subset_path']
    unique_classes = args['data']['unique_classes_per_rank']
    data_seed = args['data']['data_seed']

    # -- CRITERION
    classes_per_batch = args['criterion']['classes_per_batch']
    supervised_views = args['criterion']['supervised_views']
    batch_size = args['criterion']['supervised_batch_size']
    temperature = args['criterion']['temperature']

    # -- OPTIMIZATION
    wd = float(args['optimization']['weight_decay'])
    num_epochs = args['optimization']['epochs']
    use_lars = args['optimization']['use_lars']
    warmup = args['optimization']['warmup']
    start_lr = args['optimization']['start_lr']
    ref_lr = args['optimization']['lr']
    final_lr = args['optimization']['final_lr']
    momentum = args['optimization']['momentum']
    nesterov = args['optimization']['nesterov']

    # -- LOGGING
    folder = args['logging']['folder']
    tag = args['logging']['write_tag']
    r_file_enc = args['logging']['pretrain_path']

    # -- log/checkpointing paths (read pre-trained encoder; write fine-tuned)
    r_enc_path = os.path.join(folder, r_file_enc)
    w_enc_path = os.path.join(folder, f'{tag}-fine-tune-SNN.pth.tar')

    # -- init distributed
    world_size, rank = init_distributed()
    logger.info(f'initialized rank/world-size: {rank}/{world_size}')

    # -- init loss (effective batch = per-class batch x number of views)
    suncet = init_suncet_loss(
        num_classes=classes_per_batch,
        batch_size=batch_size * supervised_views,
        world_size=world_size,
        rank=rank,
        temperature=temperature,
        device=device)
    labels_matrix = make_labels_matrix(
        num_classes=classes_per_batch,
        s_batch_size=batch_size,
        world_size=world_size,
        device=device,
        unique_classes=unique_classes,
        smoothing=label_smoothing)

    # -- make data transforms
    transform, init_transform = make_transforms(
        dataset_name=dataset_name,
        subset_path=subset_path,
        unlabeled_frac=unlabeled_frac,
        training=True,
        split_seed=data_seed,
        basic_augmentations=True,
        normalize=normalize)
    # Class-stratified supervised loader (stratify=True), no unsupervised
    # batches (u_batch_size=None).
    (data_loader, dist_sampler) = init_data(
        dataset_name=dataset_name,
        transform=transform,
        init_transform=init_transform,
        supervised_views=supervised_views,
        u_batch_size=None,
        stratify=True,
        s_batch_size=batch_size,
        classes_per_batch=classes_per_batch,
        unique_classes=unique_classes,
        world_size=world_size,
        rank=rank,
        root_path=root_path,
        image_folder=image_folder,
        training=True,
        copy_data=copy_data)

    # -- rough estimate of labeled imgs per class used to set the number of
    #    fine-tuning iterations (1300/class for imagenet, 5000/class otherwise)
    imgs_per_class = int(1300 * (1. - unlabeled_frac)) if 'imagenet' in dataset_name \
        else int(5000 * (1. - unlabeled_frac))
    dist_sampler.set_inner_epochs(imgs_per_class // batch_size)
    ipe = len(data_loader)
    logger.info(f'initialized data-loader (ipe {ipe})')

    # -- init model and optimizer
    scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
    encoder, optimizer, scheduler = init_model(
        device=device,
        training=True,
        r_enc_path=r_enc_path,
        iterations_per_epoch=ipe,
        world_size=world_size,
        start_lr=start_lr,
        ref_lr=ref_lr,
        num_epochs=num_epochs,
        output_dim=output_dim,
        model_name=model_name,
        warmup_epochs=warmup,
        use_pred_head=use_pred_head,
        use_fp16=use_fp16,
        wd=wd,
        final_lr=final_lr,
        momentum=momentum,
        nesterov=nesterov,
        use_lars=use_lars)

    best_acc, val_top1 = None, None
    start_epoch = 0
    # -- load checkpoint (resume a previous SNN fine-tuning run)
    if load_checkpoint:
        encoder, optimizer, scaler, scheduler, start_epoch, best_acc = load_from_path(
            r_path=w_enc_path,
            encoder=encoder,
            opt=optimizer,
            scaler=scaler,
            sched=scheduler,
            device=device,
            use_fp16=use_fp16,
            ckp=True)

    for epoch in range(start_epoch, num_epochs):

        def train_step():
            # One fine-tuning pass: concatenate all support views in a batch,
            # compute the SUNCET loss against the repeated label matrix, and
            # take a scaled optimizer step per iteration.
            # -- update distributed-data-loader epoch
            dist_sampler.set_epoch(epoch)
            for i, data in enumerate(data_loader):
                # data[:-1] are image views; the last element (ground-truth
                # targets) is unused — labels come from labels_matrix.
                imgs = torch.cat([s.to(device) for s in data[:-1]], 0)
                labels = torch.cat(
                    [labels_matrix for _ in range(supervised_views)])
                with torch.cuda.amp.autocast(enabled=use_fp16):
                    optimizer.zero_grad()
                    z = encoder(imgs)
                    loss = suncet(z, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                if i % log_freq == 0:
                    logger.info('[%d, %5d] (loss: %.3f)' % (epoch + 1, i, loss))

        # NOTE(review): validation is evaluated before this epoch's
        # train_step() call, so the checkpoint below reflects the accuracy of
        # the model *entering* the epoch — confirm this ordering is intended.
        with torch.no_grad():
            with nostdout():
                val_top1, _ = val_run(
                    pretrained=copy.deepcopy(encoder),
                    subset_path=subset_path,
                    unlabeled_frac=unlabeled_frac,
                    dataset_name=dataset_name,
                    root_path=root_path,
                    image_folder=image_folder,
                    use_pred=use_pred_head,
                    normalize=normalize,
                    split_seed=data_seed)
        logger.info('[%d] (val: %.3f%%)' % (epoch + 1, val_top1))

        train_step()

        # -- logging/checkpointing (rank 0 saves only on a new best val acc)
        if (rank == 0) and ((best_acc is None) or (best_acc < val_top1)):
            best_acc = val_top1
            save_dict = {
                'encoder': encoder.state_dict(),
                'opt': optimizer.state_dict(),
                'sched': scheduler.state_dict(),
                'epoch': epoch + 1,
                'unlabel_prob': unlabeled_frac,
                'world_size': world_size,
                'batch_size': batch_size,
                'best_top1_acc': best_acc,
                'lr': ref_lr,
                'amp': scaler.state_dict()
            }
            torch.save(save_dict, w_enc_path)
        logger.info('[%d] (best-val: %.3f%%)' % (epoch + 1, best_acc))