def get_logger(work_dir, cfg):
    logger = DistSummaryWriter(work_dir)
    config_txt = os.path.join(work_dir, 'cfg.txt')
    if is_main_process():
        with open(config_txt, 'w') as fp:
            fp.write(str(cfg))
    return logger
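
# NOTE: `setup_seed` (called from main() when --deterministic is set) is imported from
# elsewhere in this repo and is not shown in this file.  The helper below is only a
# minimal sketch of what such a seeding function typically does; the underscore-prefixed
# name marks it as an illustration rather than the project's actual implementation.
# It relies on the file's existing module-level `torch` import.
def _reference_setup_seed(seed):
    import random
    import numpy as np
    random.seed(seed)                            # Python RNG
    np.random.seed(seed)                         # NumPy RNG
    torch.manual_seed(seed)                      # PyTorch CPU RNG
    torch.cuda.manual_seed_all(seed)             # PyTorch GPU RNGs (all devices)
    torch.backends.cudnn.deterministic = True    # force deterministic cuDNN kernels
    torch.backends.cudnn.benchmark = False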
def main():
    time_stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    global best_prec1, args
    best_prec1 = 0
    args = parse()

    if not len(args.data):
        raise Exception("error: No data set provided")

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    # make apex optional
    if args.opt_level is not None or args.sync_bn:
        try:
            global DDP, amp, optimizers, parallel
            from apex.parallel import DistributedDataParallel as DDP
            from apex import amp, optimizers, parallel
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to run this example.")

    # fall back to native PyTorch DDP when apex/amp is not requested
    if args.opt_level is None and args.distributed:
        from torch.nn.parallel import DistributedDataParallel as DDP

    dist_print("opt_level = {}".format(args.opt_level))
    dist_print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32),
               type(args.keep_batchnorm_fp32))
    dist_print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
    dist_print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

    torch.backends.cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        # cudnn.benchmark = False
        # cudnn.deterministic = True
        # torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)
        setup_seed(0)

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        # one process per GPU; local_rank selects the device for this process
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    args.total_batch_size = args.world_size * args.batch_size
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    args.work_dir = os.path.join(args.work_dir, time_stamp + args.arch + args.note)
    if not args.evaluate:
        if args.local_rank == 0:
            os.makedirs(args.work_dir)
        logger = DistSummaryWriter(args.work_dir)

    # create model
    if args.pretrained:
        dist_print("=> using pre-trained model '{}'".format(args.arch))
        if args.arch == 'fcanet34':
            model = fcanet34(pretrained=True)
        elif args.arch == 'fcanet50':
            model = fcanet50(pretrained=True)
        elif args.arch == 'fcanet101':
            model = fcanet101(pretrained=True)
        elif args.arch == 'fcanet152':
            model = fcanet152(pretrained=True)
        else:
            model = models.__dict__[args.arch](pretrained=True)
    else:
        dist_print("=> creating model '{}'".format(args.arch))
        if args.arch == 'fcanet34':
            model = fcanet34()
        elif args.arch == 'fcanet50':
            model = fcanet50()
        elif args.arch == 'fcanet101':
            model = fcanet101()
        elif args.arch == 'fcanet152':
            model = fcanet152()
        else:
            model = models.__dict__[args.arch]()

    if args.sync_bn:
        dist_print("using apex synced BN")
        model = parallel.convert_syncbn_model(model)

    if hasattr(torch, 'channels_last') and hasattr(torch, 'contiguous_format'):
        if args.channels_last:
            memory_format = torch.channels_last
        else:
            memory_format = torch.contiguous_format
        model = model.cuda().to(memory_format=memory_format)
    else:
        model = model.cuda()

    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size * args.world_size) / 256.
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    if args.opt_level is not None:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.opt_level,
                                          keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                          loss_scale=args.loss_scale)

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        if args.opt_level is not None:
            model = DDP(model, delay_allreduce=True)
        else:
            model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            # declare best_prec1 global so the restored value is not discarded
            # when this local function returns
            global best_prec1
            if os.path.isfile(args.resume):
                dist_print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(
                    args.resume,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                dist_print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                dist_print("=> no checkpoint found at '{}'".format(args.resume))

        resume()

    if args.evaluate:
        assert args.evaluate_model is not None
        dist_print("=> loading checkpoint '{}' for eval".format(args.evaluate_model))
        checkpoint = torch.load(
            args.evaluate_model,
            map_location=lambda storage, loc: storage.cuda(args.gpu))
        if 'state_dict' in checkpoint.keys():
            model.load_state_dict(checkpoint['state_dict'])
        else:
            # checkpoints saved without the DDP wrapper lack the 'module.' prefix; add it
            state_dict_with_module = {}
            for k, v in checkpoint.items():
                state_dict_with_module['module.' + k] = v
            model.load_state_dict(state_dict_with_module)

    # Data loading code
    if len(args.data) == 1:
        traindir = os.path.join(args.data[0], 'train')
        valdir = os.path.join(args.data[0], 'val')
    else:
        traindir = args.data[0]
        valdir = args.data[1]

    if args.arch == "inception_v3":
        raise RuntimeError("Currently, inception_v3 is not supported by this example.")
        # crop_size = 299
        # val_size = 320  # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    # DALI pipelines for training and validation, sharded across distributed workers
    pipe = HybridTrainPipe(batch_size=args.batch_size,
                           num_threads=args.workers,
                           device_id=args.local_rank,
                           data_dir=traindir,
                           crop=crop_size,
                           dali_cpu=args.dali_cpu,
                           shard_id=args.local_rank,
                           num_shards=args.world_size)
    pipe.build()
    train_loader = DALIClassificationIterator(pipe, reader_name="Reader", fill_last_batch=False)

    pipe = HybridValPipe(batch_size=args.batch_size,
                         num_threads=args.workers,
                         device_id=args.local_rank,
                         data_dir=valdir,
                         crop=crop_size,
                         size=val_size,
                         shard_id=args.local_rank,
                         num_shards=args.world_size)
    pipe.build()
    val_loader = DALIClassificationIterator(pipe, reader_name="Reader", fill_last_batch=False)

    # criterion = nn.CrossEntropyLoss().cuda()
    criterion = CrossEntropyLabelSmooth().cuda()

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    # cosine annealing over 95 epochs' worth of iterations after a 5-epoch linear warmup
    len_epoch = int(math.ceil(train_loader._size / args.batch_size))
    T_max = 95 * len_epoch
    warmup_iters = 5 * len_epoch
    scheduler = CosineAnnealingLR(optimizer, T_max, warmup='linear', warmup_iters=warmup_iters)

    total_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        avg_train_time = train(train_loader, model, criterion, optimizer, epoch, logger, scheduler)
        total_time.update(avg_train_time)
        torch.cuda.empty_cache()

        # evaluate on validation set
        [prec1, prec5] = validate(val_loader, model, criterion)
        logger.add_scalar('Val/prec1', prec1, global_step=epoch)
        logger.add_scalar('Val/prec5', prec5, global_step=epoch)

        # remember best prec@1 and save checkpoint (rank 0 only)
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, work_dir=args.work_dir)
            if epoch == args.epochs - 1:
                dist_print('##Best Top-1 {0}\n'
                           '##Perf {1}'.format(best_prec1,
                                               args.total_batch_size / total_time.avg))
                with open(os.path.join(args.work_dir, 'res.txt'), 'w') as f:
                    f.write('arch: {0} \n best_prec1 {1}'.format(args.arch + args.note,
                                                                 best_prec1))

        # DALI iterators must be reset at the end of each epoch
        train_loader.reset()
        val_loader.reset()
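
# NOTE: CrossEntropyLabelSmooth used in main() is defined elsewhere in this repo.  The
# class below is only a minimal reference sketch of standard label-smoothing cross
# entropy for readers unfamiliar with the technique; the underscore-prefixed name and
# the defaults num_classes=1000, epsilon=0.1 are illustrative assumptions, not the
# project's actual implementation.
class _ReferenceCrossEntropyLabelSmooth(torch.nn.Module):
    def __init__(self, num_classes=1000, epsilon=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        # inputs: (N, num_classes) raw logits; targets: (N,) integer class labels
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        # blend the one-hot target with a uniform distribution over all classes
        smoothed = (1.0 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        # per-sample cross entropy against the smoothed target, averaged over the batch
        return (-smoothed * log_probs).mean(0).sum()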