def main():
    global args
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")

    # build data
    train_dataset = MultiCropDataset(
        args.data_path,
        args.size_crops,
        args.nmb_crops,
        args.min_scale_crops,
        args.max_scale_crops,
        return_index=True,
    )
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True,
    )
    logger.info("Building data done with {} images loaded.".format(len(train_dataset)))

    # build model
    model = resnet_models.__dict__[args.arch](
        normalize=True,
        hidden_mlp=args.hidden_mlp,
        output_dim=args.feat_dim,
        nmb_prototypes=args.nmb_prototypes,
    )

    # synchronize batch norm layers
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    elif args.sync_bn == "apex":
        process_group = None
        if args.world_size // 8 > 0:
            process_group = apex.parallel.create_syncbn_process_group(args.world_size // 8)
        model = apex.parallel.convert_syncbn_model(model, process_group=process_group)

    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    # build optimizer
    # base_lr=4.8 wd=1e-6
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
    # Using Dist LARC Optimizer
    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)

    # LR Scheduling
    warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr, len(train_loader) * args.warmup_epochs)
    iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs))
    cosine_lr_schedule = np.array([args.final_lr + 0.5 * (args.base_lr - args.final_lr) * (1 + \
                         math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs)))) for t in iters])
    lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    logger.info("Building optimizer done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
    )
    start_epoch = to_restore["epoch"]

    # build the memory bank (one file per rank)
    mb_path = os.path.join(args.dump_path, "mb" + str(args.rank) + ".pth")
    if os.path.isfile(mb_path):
        mb_ckp = torch.load(mb_path)
        local_memory_index = mb_ckp["local_memory_index"]
        local_memory_embeddings = mb_ckp["local_memory_embeddings"]
    else:
        local_memory_index, local_memory_embeddings = init_memory(train_loader, model)

    cudnn.benchmark = True

    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # train the network
        scores, local_memory_index, local_memory_embeddings = train(
            train_loader,
            model,
            optimizer,
            epoch,
            lr_schedule,
            local_memory_index,
            local_memory_embeddings,
        )
        training_stats.update(scores)

        # save checkpoints
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints, "ckp-" + str(epoch) + ".pth"),
                )
        torch.save(
            {"local_memory_embeddings": local_memory_embeddings,
             "local_memory_index": local_memory_index},
            mb_path,
        )
def main():
    global args
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")

    # build data
    train_dataset = MultiCropDataset(
        args.data_path,
        args.size_crops,
        args.nmb_crops,
        args.min_scale_crops,
        args.max_scale_crops,
        pil_blur=args.use_pil_blur,
    )
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True,
    )
    logger.info("Building data done with {} images loaded.".format(len(train_dataset)))

    # build model
    model = resnet_models.__dict__[args.arch](
        normalize=True,
        hidden_mlp=args.hidden_mlp,
        output_dim=args.feat_dim,
        nmb_prototypes=args.nmb_prototypes,
    )

    # synchronize batch norm layers
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    elif args.sync_bn == "apex":
        process_group = None
        if args.world_size // 8 > 0:
            process_group = apex.parallel.create_syncbn_process_group(args.world_size // 8)
        model = apex.parallel.convert_syncbn_model(model, process_group=process_group)

    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    # build optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)

    # LR schedule: linear warmup followed by cosine decay, one value per iteration
    warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr, len(train_loader) * args.warmup_epochs)
    iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs))
    cosine_lr_schedule = np.array([args.final_lr + 0.5 * (args.base_lr - args.final_lr) * (1 + \
                         math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs)))) for t in iters])
    lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    logger.info("Building optimizer done.")

    # init mixed precision
    if args.use_fp16:
        model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1")
        logger.info("Initializing mixed precision done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
        amp=apex.amp,
    )
    start_epoch = to_restore["epoch"]

    # build the queue (one file per rank)
    queue = None
    queue_path = os.path.join(args.dump_path, "queue" + str(args.rank) + ".pth")
    if os.path.isfile(queue_path):
        queue = torch.load(queue_path)["queue"]
    # the queue needs to be divisible by the batch size
    # args.queue_length -= args.queue_length % (args.batch_size * args.world_size)

    cudnn.benchmark = True

    # initialize the queue with a forward pass over the data; skip this when a
    # queue was restored from disk so resumed runs keep their state
    if queue is None:
        logger.info("start initializing the queue")
        queue = init_queue(train_loader, model, args)
        logger.info("queue initialization done")

    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # optionally starts a queue
        # queue shape: (Ncrops, Lqueue, feat) --> (NClass, NCrops, Lqueue, feat)
        # if queue is None:
        #     queue = torch.randn(1000, args.feat_dim).cuda()
        #     queue = nn.functional.normalize(queue, dim=1, p=2)

        # train the network
        scores, queue = train(train_loader, model, optimizer, epoch, lr_schedule, queue, args)
        training_stats.update(scores)

        # save checkpoints
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            if args.use_fp16:
                save_dict["amp"] = apex.amp.state_dict()
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints, "ckp-" + str(epoch) + ".pth"),
                )
        if queue is not None:
            torch.save({"queue": queue}, queue_path)
def main():
    global args
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")

    # build data
    if args.type == 0:
        traindir = os.path.join(args.data_path, 'train')
        train_dataset = MultiCropDataset(
            traindir,
            args.size_crops,
            args.nmb_crops,
            args.min_scale_crops,
            args.max_scale_crops,
        )
    else:
        from src.inside_crop import inside_crop, TwoCropsTransform
        from src.multicropdataset import CropDataset
        traindir = os.path.join(args.data_path, 'train')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        cur_transform = inside_crop(args.size_crops, args.nmb_crops,
                                    args.min_scale_crops, args.max_scale_crops, normalize)
        train_dataset = CropDataset(traindir, cur_transform)
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True,
    )

    # configure datasets for kNN monitoring: a labeled bank drawn from train
    # and the standard center-cropped val split
    traindir = os.path.join(args.data_path, 'train')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    testdir = os.path.join(args.data_path, 'val')
    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    # val_dataset = datasets.ImageFolder(traindir, transform_test)
    val_dataset = imagenet(traindir, 0.2, transform_test)
    test_dataset = datasets.ImageFolder(testdir, transform_test)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.knn_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.knn_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=test_sampler, drop_last=False)
    logger.info("Building data done with {} images loaded.".format(len(train_dataset)))

    # build model
    model = resnet_models.__dict__[args.arch](
        normalize=True,
        hidden_mlp=args.hidden_mlp,
        output_dim=args.feat_dim,
        nmb_prototypes=args.nmb_prototypes,
    )

    # synchronize batch norm layers (apex path removed in this variant)
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    # elif args.sync_bn == "apex":
    #     process_group = None
    #     if args.world_size // 8 > 0:
    #         process_group = apex.parallel.create_syncbn_process_group(args.world_size // 8)
    #     model = apex.parallel.convert_syncbn_model(model, process_group=process_group)

    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    # build optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
    # optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)
    optimizer = SGD_LARC(optimizer, trust_coefficient=0.001, clip=False, eps=1e-8)

    warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr, len(train_loader) * args.warmup_epochs)
    iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs))
    cosine_lr_schedule = np.array([args.final_lr + 0.5 * (args.base_lr - args.final_lr) * (1 + \
                         math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs)))) for t in iters])
    lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    logger.info("Building optimizer done.")

    # init mixed precision (disabled in this variant)
    # if args.use_fp16:
    #     model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1")
    #     logger.info("Initializing mixed precision done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
        # amp=apex.amp,
    )
    start_epoch = to_restore["epoch"]

    # build the queue (one file per rank)
    queue = None
    queue_path = os.path.join(args.dump_path, "queue" + str(args.rank) + ".pth")
    if os.path.isfile(queue_path):
        queue = torch.load(queue_path)["queue"]
    # the queue needs to be divisible by the batch size
    args.queue_length -= args.queue_length % (args.batch_size * args.world_size)

    cudnn.benchmark = True

    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # optionally starts a queue
        if args.queue_length > 0 and epoch >= args.epoch_queue_starts and queue is None:
            queue = torch.zeros(
                len(args.crops_for_assign),
                args.queue_length // args.world_size,
                args.feat_dim,
            ).cuda()

        # train the network
        scores, queue = train(train_loader, model, optimizer, epoch, lr_schedule, queue)
        training_stats.update(scores)

        # periodically monitor representation quality with a kNN classifier;
        # note that each rank evaluates only its own sampler shard
        if epoch % args.knn_freq == 0:
            logger.info("GPU memory before cache clear: %.1f MiB"
                        % (torch.cuda.memory_allocated() / 1024 / 1024))
            torch.cuda.empty_cache()
            logger.info("GPU memory after cache clear: %.1f MiB"
                        % (torch.cuda.memory_allocated() / 1024 / 1024))
            # should also work with a much smaller knn batch size and a sampler;
            # knn_monitor_fast(model.module.encoder_q, ...) is an alternative
            knn_test_acc = knn_monitor(
                model, val_loader, test_loader,
                global_k=min(args.knn_neighbor, len(val_loader.dataset)),
            )
            logger.info("*KNN monitor Accuracy: {}".format(knn_test_acc))
            torch.cuda.empty_cache()

        # save checkpoints
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            # if args.use_fp16:
            #     save_dict["amp"] = apex.amp.state_dict()
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints, "ckp-" + str(epoch) + ".pth"),
                )
        if queue is not None:
            torch.save({"queue": queue}, queue_path)