def main():
    global args, best_prec1
    args = parser.parse_args()
    with open(args.config) as f:
        config = yaml.load(f)

    for key in config:
        for k, v in config[key].items():
            setattr(args, k, v)

    print('Enabled distributed training.')

    rank, world_size = init_dist(backend='nccl', port=args.port)
    args.rank = rank
    args.world_size = world_size

    # create model
    print("=> creating model '{}'".format(args.model))
    if 'resnetv1sn' in args.model:
        model = models.__dict__[args.model](
            using_moving_average=args.using_moving_average,
            using_bn=args.using_bn,
            last_gamma=args.last_gamma)
    else:
        model = models.__dict__[args.model](
            using_moving_average=args.using_moving_average,
            using_bn=args.using_bn)

    model.cuda()
    broadcast_params(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # auto resume from a checkpoint
    model_dir = args.model_dir
    start_epoch = 0
    if args.rank == 0 and not os.path.exists(model_dir):
        os.makedirs(model_dir)

    if args.evaluate:
        load_state_ckpt(args.checkpoint_path, model)
    else:
        best_prec1, start_epoch = load_state(model_dir, model, optimizer=optimizer)

    if args.rank == 0:
        writer = SummaryWriter(model_dir)
    else:
        writer = None

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = ImagenetDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            ColorAugmentation(),
            normalize,
        ]))
    val_dataset = ImagenetDataset(
        args.val_root,
        args.val_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size // args.world_size,
                              shuffle=False,
                              num_workers=args.workers,
                              pin_memory=False,
                              sampler=train_sampler)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size // args.world_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=False,
                            sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion, 0, writer)
        return

    niters = len(train_loader)

    lr_scheduler = LRScheduler(optimizer, niters, args)

    for epoch in range(start_epoch, args.epochs):
        train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, lr_scheduler, epoch, writer)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, writer)

        if rank == 0:
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                model_dir, {
                    'epoch': epoch + 1,
                    'model': args.model,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
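# NOTE: `broadcast_params` comes from the repo's distributed utilities and is not
# shown in this file. The sketch below illustrates what such a helper typically
# does (synchronize every parameter/buffer from rank 0 to all workers), assuming
# torch.distributed has already been initialized by `init_dist`. The name
# `broadcast_params_sketch` and its body are illustrative assumptions, not the
# repo's actual implementation.
import torch.distributed as dist


def broadcast_params_sketch(model):
    """Broadcast all parameters and buffers from rank 0 so every worker starts identical."""
    for tensor in model.state_dict().values():
        dist.broadcast(tensor, src=0)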
def main():
    global args, best_prec1
    args = parser.parse_args()
    with open(args.config) as f:
        config = yaml.load(f)

    for key in config:
        for k, v in config[key].items():
            setattr(args, k, v)

    print('Enabled distributed training.')

    rank, world_size = init_dist(backend='nccl', port=args.port)
    args.rank = rank
    args.world_size = world_size

    np.random.seed(args.seed * args.rank)
    torch.manual_seed(args.seed * args.rank)
    torch.cuda.manual_seed(args.seed * args.rank)
    torch.cuda.manual_seed_all(args.seed * args.rank)

    # create model
    print("=> creating model '{}'".format(args.model))
    if args.SinglePath:
        architecture = 20 * [0]
        channels_scales = 20 * [1.0]
        # load derived child network
        log_alpha = torch.load(
            args.checkpoint_path,
            map_location='cuda:{}'.format(torch.cuda.current_device()))['state_dict']['log_alpha']
        weights = torch.zeros_like(log_alpha).scatter_(
            1, torch.argmax(log_alpha, dim=-1).view(-1, 1), 1)
        model = ShuffleNetV2_OneShot(args=args, architecture=architecture,
                                     channels_scales=channels_scales, weights=weights)
        model.cuda()
        broadcast_params(model)
        for v in model.parameters():
            if v.requires_grad:
                if v.grad is None:
                    v.grad = torch.zeros_like(v)
        model.log_alpha.grad = torch.zeros_like(model.log_alpha)

    if not args.retrain:
        load_state_ckpt(args.checkpoint_path, model)
        checkpoint = torch.load(
            args.checkpoint_path,
            map_location='cuda:{}'.format(torch.cuda.current_device()))
        args.base_lr = checkpoint['optimizer']['param_groups'][0]['lr']

    if args.reset_bn_stat:
        model._reset_bn_running_stats()

    # define loss function (criterion) and optimizer
    criterion = CrossEntropyLoss(smooth_eps=0.1,
                                 smooth_dist=(torch.ones(1000) * 0.001).cuda()).cuda()

    wo_wd_params = []
    wo_wd_param_names = []
    network_params = []
    network_param_names = []

    for name, mod in model.named_modules():
        # if isinstance(mod, (nn.BatchNorm2d, SwitchNorm2d)):
        if isinstance(mod, nn.BatchNorm2d):
            for key, value in mod.named_parameters():
                wo_wd_param_names.append(name + '.' + key)

    for key, value in model.named_parameters():
        if key != 'log_alpha':
            if value.requires_grad:
                if key in wo_wd_param_names:
                    wo_wd_params.append(value)
                else:
                    network_params.append(value)
                    network_param_names.append(key)

    params = [
        {'params': network_params,
         'lr': args.base_lr,
         'weight_decay': args.weight_decay},
        {'params': wo_wd_params,
         'lr': args.base_lr,
         'weight_decay': 0.},
    ]
    param_names = [network_param_names, wo_wd_param_names]
    if args.rank == 0:
        print('>>> params w/o weight decay: ', wo_wd_param_names)
    optimizer = torch.optim.SGD(params, momentum=args.momentum)
    arch_optimizer = None

    # auto resume from a checkpoint
    remark = 'imagenet_'
    remark += 'epo_' + str(args.epochs) + '_layer_' + str(args.layers) \
        + '_batch_' + str(args.batch_size) \
        + '_lr_' + str(float("{0:.2f}".format(args.base_lr))) \
        + '_seed_' + str(args.seed)

    if args.remark != 'none':
        remark += '_' + args.remark

    args.save = 'search-{}-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"), remark)
    args.save_log = 'nas-{}-{}'.format(time.strftime("%Y%m%d-%H%M%S"), remark)
    generate_date = str(datetime.now().date())

    path = os.path.join(generate_date, args.save)
    if args.rank == 0:
        log_format = '%(asctime)s %(message)s'
        utils.create_exp_dir(generate_date, path, scripts_to_save=glob.glob('*.py'))
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(path, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        logging.info("args = %s", args)
        writer = SummaryWriter('./runs/' + generate_date + '/' + args.save_log)
    else:
        writer = None

    # model_dir = args.model_dir
    model_dir = path
    start_epoch = 0

    if args.evaluate:
        load_state_ckpt(args.checkpoint_path, model)
    else:
        best_prec1, start_epoch = load_state(model_dir, model, optimizer=optimizer)

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = ImagenetDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_dataset_wo_ms = ImagenetDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    val_dataset = ImagenetDataset(
        args.val_root,
        args.val_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size // args.world_size,
        shuffle=False, num_workers=args.workers,
        pin_memory=False, sampler=train_sampler)
    train_loader_wo_ms = DataLoader(
        train_dataset_wo_ms, batch_size=args.batch_size // args.world_size,
        shuffle=False, num_workers=args.workers,
        pin_memory=False, sampler=train_sampler)
    val_loader = DataLoader(
        val_dataset, batch_size=50, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion, 0, writer, logging)
        return

    niters = len(train_loader)

    lr_scheduler = LRScheduler(optimizer, niters, args)

    for epoch in range(start_epoch, args.epochs):
        train_sampler.set_epoch(epoch)

        if args.rank == 0 and args.SinglePath:
            logging.info('epoch %d', epoch)

        # evaluate on validation set after loading the model
        if epoch == 0 and not args.reset_bn_stat:
            prec1 = validate(val_loader, model, criterion, epoch, writer, logging)

        # train for one epoch
        if epoch >= args.epochs - 5 and args.lr_mode == 'step' and args.off_ms and args.retrain:
            train(train_loader_wo_ms, model, criterion, optimizer, arch_optimizer,
                  lr_scheduler, epoch, writer, logging)
        else:
            train(train_loader, model, criterion, optimizer, arch_optimizer,
                  lr_scheduler, epoch, writer, logging)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, writer, logging)

        if rank == 0:
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(model_dir, {
                'epoch': epoch + 1,
                'model': args.model,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
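# How the derived child network above is decoded from `log_alpha`: a tiny
# standalone demo of the argmax + scatter_ one-hot trick. The tensor values are
# made up purely for illustration.
import torch

log_alpha_demo = torch.tensor([[0.1, 2.0, -1.0],
                               [1.5, 0.0,  0.3]])            # (num_layers, num_ops)
one_hot = torch.zeros_like(log_alpha_demo).scatter_(
    1, torch.argmax(log_alpha_demo, dim=-1).view(-1, 1), 1)
print(one_hot)  # tensor([[0., 1., 0.], [1., 0., 0.]]) -- one candidate op kept per layer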
def main():
    global args, best_prec1
    args = parser.parse_args()
    with open(args.config) as f:
        config = yaml.load(f)

    for key in config:
        for k, v in config[key].items():
            setattr(args, k, v)

    print('Enabled distributed training.')

    rank, world_size = init_dist(backend='nccl', port=args.port)
    args.rank = rank
    args.world_size = world_size

    np.random.seed(args.seed * args.rank)
    torch.manual_seed(args.seed * args.rank)
    torch.cuda.manual_seed(args.seed * args.rank)
    torch.cuda.manual_seed_all(args.seed * args.rank)
    print('random seed: ', args.seed * args.rank)

    # create model
    print("=> creating model '{}'".format(args.model))
    if args.SinglePath:
        architecture = 20 * [0]
        channels_scales = 20 * [1.0]
        model = ShuffleNetV2_OneShot(args=args, architecture=architecture,
                                     channels_scales=channels_scales)
        model.cuda()
        broadcast_params(model)
        for v in model.parameters():
            if v.requires_grad:
                if v.grad is None:
                    v.grad = torch.zeros_like(v)
        model.log_alpha.grad = torch.zeros_like(model.log_alpha)

    criterion = CrossEntropyLoss(smooth_eps=0.1,
                                 smooth_dist=(torch.ones(1000) * 0.001).cuda()).cuda()

    wo_wd_params = []
    wo_wd_param_names = []
    network_params = []
    network_param_names = []

    for name, mod in model.named_modules():
        if isinstance(mod, nn.BatchNorm2d):
            for key, value in mod.named_parameters():
                wo_wd_param_names.append(name + '.' + key)

    for key, value in model.named_parameters():
        if key != 'log_alpha':
            if value.requires_grad:
                if key in wo_wd_param_names:
                    wo_wd_params.append(value)
                else:
                    network_params.append(value)
                    network_param_names.append(key)

    params = [
        {'params': network_params,
         'lr': args.base_lr,
         'weight_decay': args.weight_decay},
        {'params': wo_wd_params,
         'lr': args.base_lr,
         'weight_decay': 0.},
    ]
    param_names = [network_param_names, wo_wd_param_names]
    if args.rank == 0:
        print('>>> params w/o weight decay: ', wo_wd_param_names)
    optimizer = torch.optim.SGD(params, momentum=args.momentum)

    if args.SinglePath:
        arch_optimizer = torch.optim.Adam(
            [param for name, param in model.named_parameters() if name == 'log_alpha'],
            lr=args.arch_learning_rate,
            betas=(0.5, 0.999),
            weight_decay=args.arch_weight_decay
        )

    # auto resume from a checkpoint
    remark = 'imagenet_'
    remark += 'epo_' + str(args.epochs) + '_layer_' + str(args.layers) \
        + '_batch_' + str(args.batch_size) + '_lr_' + str(args.base_lr) \
        + '_seed_' + str(args.seed) + '_pretrain_' + str(args.pretrain_epoch)

    if args.early_fix_arch:
        remark += '_early_fix_arch'

    if args.flops_loss:
        remark += '_flops_loss_' + str(args.flops_loss_coef)

    if args.remark != 'none':
        remark += '_' + args.remark

    args.save = 'search-{}-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"), remark)
    args.save_log = 'nas-{}-{}'.format(time.strftime("%Y%m%d-%H%M%S"), remark)
    generate_date = str(datetime.now().date())

    path = os.path.join(generate_date, args.save)
    if args.rank == 0:
        log_format = '%(asctime)s %(message)s'
        utils.create_exp_dir(generate_date, path, scripts_to_save=glob.glob('*.py'))
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(path, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        logging.info("args = %s", args)
        writer = SummaryWriter('./runs/' + generate_date + '/' + args.save_log)
    else:
        writer = None

    model_dir = path
    start_epoch = 0

    if args.evaluate:
        load_state_ckpt(args.checkpoint_path, model)
    else:
        best_prec1, start_epoch = load_state(model_dir, model, optimizer=optimizer)

    cudnn.benchmark = True
    cudnn.enabled = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = ImagenetDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_dataset_wo_ms = ImagenetDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    val_dataset = ImagenetDataset(
        args.val_root,
        args.val_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size // args.world_size,
        shuffle=False, num_workers=args.workers,
        pin_memory=False, sampler=train_sampler)
    train_loader_wo_ms = DataLoader(
        train_dataset_wo_ms, batch_size=args.batch_size // args.world_size,
        shuffle=False, num_workers=args.workers,
        pin_memory=False, sampler=train_sampler)
    val_loader = DataLoader(
        val_dataset, batch_size=50, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion, 0, writer, logging)
        return

    niters = len(train_loader)

    lr_scheduler = LRScheduler(optimizer, niters, args)

    for epoch in range(start_epoch, 85):
        train_sampler.set_epoch(epoch)

        if args.early_fix_arch:
            if len(model.fix_arch_index.keys()) > 0:
                for key, value_lst in model.fix_arch_index.items():
                    model.log_alpha.data[key, :] = value_lst[1]
            sort_log_alpha = torch.topk(F.softmax(model.log_alpha.data, dim=-1), 2)
            argmax_index = (sort_log_alpha[0][:, 0] - sort_log_alpha[0][:, 1] >= 0.3)
            for id in range(argmax_index.size(0)):
                if argmax_index[id] == 1 and id not in model.fix_arch_index.keys():
                    model.fix_arch_index[id] = [sort_log_alpha[1][id, 0].item(),
                                                model.log_alpha.detach().clone()[id, :]]

        if args.rank == 0 and args.SinglePath:
            logging.info('epoch %d', epoch)
            logging.info(model.log_alpha)
            logging.info(F.softmax(model.log_alpha, dim=-1))
            logging.info('flops %fM', model.cal_flops())

        # train for one epoch
        if epoch >= args.epochs - 5 and args.lr_mode == 'step' and args.off_ms:
            train(train_loader_wo_ms, model, criterion, optimizer, arch_optimizer,
                  lr_scheduler, epoch, writer, logging)
        else:
            train(train_loader, model, criterion, optimizer, arch_optimizer,
                  lr_scheduler, epoch, writer, logging)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, writer, logging)

        if args.gen_max_child:
            args.gen_max_child_flag = True
            prec1 = validate(val_loader, model, criterion, epoch, writer, logging)
            args.gen_max_child_flag = False

        if rank == 0:
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(model_dir, {
                'epoch': epoch + 1,
                'model': args.model,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
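# NOTE: `LRScheduler(optimizer, niters, args)` is used throughout these scripts but
# its implementation is not shown, and the `train()` functions that drive it are not
# included either. The class below is a hypothetical sketch of a per-iteration
# scheduler with the same constructor signature; the `update(i, epoch)` method name,
# the warmup handling, and the attribute names read from `args` (warmup_epochs,
# lr_steps) are all assumptions for illustration only.
import math


class WarmupLRSchedulerSketch(object):
    """Hypothetical warmup + step/cosine scheduler matching the constructor used above."""

    def __init__(self, optimizer, niters, args):
        self.optimizer = optimizer
        self.niters = niters                                   # iterations per epoch
        self.base_lr = args.base_lr
        self.mode = getattr(args, 'lr_mode', 'cosine')         # 'step' or 'cosine'
        self.warmup_epochs = getattr(args, 'warmup_epochs', 0)
        self.epochs = args.epochs
        self.step_epochs = getattr(args, 'lr_steps', [30, 60, 80])

    def update(self, i, epoch):
        T = epoch * self.niters + i
        warmup_iters = self.warmup_epochs * self.niters
        if T < warmup_iters:
            # linear warmup from 0 to base_lr
            lr = self.base_lr * T / max(1, warmup_iters)
        elif self.mode == 'step':
            # divide by 10 at each milestone epoch
            lr = self.base_lr * (0.1 ** sum(epoch >= s for s in self.step_epochs))
        else:
            # cosine decay over the remaining iterations
            total = self.epochs * self.niters
            lr = 0.5 * self.base_lr * (
                1 + math.cos(math.pi * (T - warmup_iters) / max(1, total - warmup_iters)))
        for group in self.optimizer.param_groups:
            group['lr'] = lr
        return lr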
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    args.rank = args.rank * ngpus_per_node + gpu
    if args.distributed:
        print(args.backend, args.world_size, args.rank)
        dist.init_process_group(backend=args.backend,
                                init_method='tcp://127.0.0.1:6668',
                                world_size=args.world_size,
                                rank=args.rank)
        print('Enabled distributed training.')

    # create model
    print("=> creating model '{}'".format(args.model))
    model = models.__dict__[args.model](N=args.N, M=args.M)
    torch.cuda.set_device(args.gpu)

    ipClass = PruningMethodTransposableBlockL1(block_size=args.M, topk=args.N)
    if args.load_mask:
        load_state_and_masks(model, args)
        print("Masks loaded!")
    else:
        for n, m in model.named_modules():
            if isinstance(m, SparseConvTranspose) or isinstance(m, SparseLinearTranspose):
                # m.maskBuff.data = ipClass.compute_mask(m.weight, torch.ones_like(m.weight))
                setattr(m.weight, "mask",
                        ipClass.compute_mask(m.weight, torch.ones_like(m.weight)))
        if args.save_mask:
            save_masks(model, args)
            print("Masks saved!")

    model.cuda(args.gpu)
    # args.batch_size = int(args.batch_size / ngpus_per_node)
    args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    # broadcast_params(model)
    print(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if args.sparse_optimizer:
        optimizer = sparse_optimizer.SGD(model.parameters(), args.base_lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), args.base_lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # auto resume from a checkpoint
    model_dir = args.model_dir
    start_epoch = 0
    best_prec1 = 0
    if args.rank == 0 and not os.path.exists(model_dir):
        os.makedirs(model_dir)

    if args.evaluate:
        load_state_ckpt(args.checkpoint_path, model)
    else:
        best_prec1, start_epoch = load_state(model_dir, model, optimizer=optimizer)

    if args.rank == 0 or not args.distributed:
        writer = SummaryWriter(model_dir)
    else:
        writer = None

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = datasets.ImageFolder(
        args.train_root,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            ColorAugmentation(),
            normalize,
        ]))
    val_dataset = datasets.ImageFolder(
        args.val_root,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = DistributedSampler(train_dataset)
        val_sampler = DistributedSampler(val_dataset)
    else:
        train_sampler = None
        val_sampler = None

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size // args.world_size,
                              shuffle=False,
                              num_workers=args.workers,
                              pin_memory=False,
                              sampler=train_sampler)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size // args.world_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=False,
                            sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion, 0, writer)
        return

    niters = len(train_loader)

    lr_scheduler = LRScheduler(optimizer, niters, args)

    for epoch in range(start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, lr_scheduler, epoch, writer, args)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, writer, args)

        if args.rank == 0:
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                model_dir, {
                    'epoch': epoch + 1,
                    'model': args.model,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
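# For intuition only: `PruningMethodTransposableBlockL1.compute_mask` above produces
# an N:M mask that also remains valid when the weight matrix is transposed; its exact
# algorithm is not shown here. The sketch below implements only the simpler,
# non-transposable N:M rule (keep the `topk` largest magnitudes in every group of
# `block_size` consecutive weights along the input dimension), to illustrate the mask
# shape and semantics. The function name and defaults are illustrative assumptions.
import torch


def simple_nm_mask(weight: torch.Tensor, block_size: int = 4, topk: int = 2) -> torch.Tensor:
    """Return a 0/1 mask with `topk` ones per `block_size` group along the last dim."""
    grouped = weight.reshape(-1, block_size)               # group along the input dim
    idx = grouped.abs().topk(topk, dim=1).indices          # positions of largest |w|
    mask = torch.zeros_like(grouped).scatter_(1, idx, 1.0)
    return mask.reshape_as(weight)


# usage: mask a linear layer's weight (in_features must be divisible by block_size)
w = torch.randn(8, 16)
mask = simple_nm_mask(w, block_size=4, topk=2)
assert mask.sum() == w.numel() // 2                        # exactly 2:4 sparsity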
def main():
    global args, best_prec1
    args = parser.parse_args()
    with open(args.config) as f:
        config = yaml.load(f)

    for key in config:
        for k, v in config[key].items():
            setattr(args, k, v)

    rank, world_size = init_dist(backend='nccl', port=args.port)
    args.rank = rank
    args.world_size = world_size

    # create model
    model = ArcFaceWithLoss(args.backbone, args.class_num, args.norm_func,
                            args.embedding_size, args.use_se)
    model.cuda()
    broadcast_params(model)

    optimizer = torch.optim.SGD(model.parameters(), args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # auto resume from a checkpoint
    model_dir = args.model_dir
    start_epoch = 0
    if args.rank == 0 and not os.path.exists(model_dir):
        os.makedirs(model_dir)

    best_prec1, start_epoch = load_state(model_dir, model, optimizer=optimizer)

    if args.rank == 0:
        writer = SummaryWriter(model_dir)
    else:
        writer = None

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5])
    train_dataset = FaceDataset(
        True,
        args,
        transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = BigdataSampler(
        train_dataset,
        num_sub_epochs=2,
        finegrain_factor=10000,
        seed=1000,
    )

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size // args.world_size,
                              shuffle=False,
                              num_workers=args.workers,
                              pin_memory=False,
                              sampler=train_sampler)

    niters = len(train_loader)

    lr_scheduler = LRScheduler(optimizer, niters, args)

    for epoch in range(start_epoch, args.epochs):
        train(train_loader, model, optimizer, lr_scheduler, epoch, writer)

        if rank == 0:
            save_checkpoint(
                model_dir, {
                    'epoch': epoch + 1,
                    'model': args.backbone,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, False)
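# NOTE: the auto-resume helper `load_state(model_dir, model, optimizer=...)` used in
# the training entry points above is not defined in this file. The sketch below shows
# one plausible implementation returning `(best_prec1, start_epoch)`; the checkpoint
# filename 'checkpoint.pth.tar' and the dict layout are assumptions based on the
# `save_checkpoint` calls above, not the repo's actual code.
import os
import torch


def load_state_sketch(model_dir, model, optimizer=None):
    """Resume from the latest checkpoint in `model_dir` if one exists."""
    ckpt_path = os.path.join(model_dir, 'checkpoint.pth.tar')   # assumed filename
    if not os.path.isfile(ckpt_path):
        return 0, 0                                             # best_prec1, start_epoch
    ckpt = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(ckpt['state_dict'])
    if optimizer is not None and 'optimizer' in ckpt:
        optimizer.load_state_dict(ckpt['optimizer'])
    return ckpt.get('best_prec1', 0), ckpt.get('epoch', 0)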
def worker(rank, world_size, args):
    # pylint: disable=too-many-statements
    if rank == 0:
        save_dir = os.path.join(args.save, args.arch,
                                "b{}".format(args.batch_size * world_size))
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(save_dir, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)

    if world_size > 1:
        # Initialize distributed process group
        logging.info("init distributed process group {} / {}".format(rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    save_dir = os.path.join(args.save, args.arch)

    if rank == 0:
        prefixs = ['train', 'valid']
        writers = {
            prefix: SummaryWriter(os.path.join(args.output, prefix))
            for prefix in prefixs
        }

    model = getattr(M, args.arch)()
    step_start = 0
    # if args.model:
    #     logging.info("load weights from %s", args.model)
    #     model.load_state_dict(mge.load(args.model))
    #     step_start = int(args.model.split("-")[1].split(".")[0])

    optimizer = optim.SGD(
        get_parameters(model),
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    # Define train and valid graph
    def train_func(image, label):
        model.train()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        optimizer.backward(loss)  # compute gradients
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss) / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1) / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5) / dist.get_world_size()
        return loss, acc1, acc5

    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss) / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1) / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5) / dist.get_world_size()
        return loss, acc1, acc5

    # Build train and valid datasets
    logging.info("preparing dataset..")
    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_dataset = datasets.ImageNet(split='train', transform=transform)
    train_sampler = torch.utils.data.RandomSampler(train_dataset)
    train_queue = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              shuffle=False,
                                              drop_last=True,
                                              pin_memory=True,
                                              num_workers=args.workers)
    train_queue = iter(train_queue)

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    valid_dataset = datasets.ImageNet(split='val', transform=transform)
    valid_sampler = torch.utils.data.SequentialSampler(valid_dataset)
    valid_queue = torch.utils.data.DataLoader(valid_dataset,
                                              batch_size=100,
                                              sampler=valid_sampler,
                                              shuffle=False,
                                              drop_last=False,
                                              num_workers=args.workers)

    # Start training
    objs = AverageMeter("Loss")
    top1 = AverageMeter("Acc@1")
    top5 = AverageMeter("Acc@5")
    total_time = AverageMeter("Time")

    t = time.time()
    best_valid_acc = 0
    for step in range(step_start, args.steps + 1):
        # Linear learning rate decay
        decay = 1.0
        decay = 1 - float(step) / args.steps if step < args.steps else 0
        for param_group in optimizer.param_groups:
            param_group["lr"] = args.learning_rate * decay

        image, label = next(train_queue)
        time_data = time.time() - t
        # image = image.astype("float32")
        # label = label.astype("int32")

        n = image.shape[0]
        optimizer.zero_grad()
        loss, acc1, acc5 = train_func(image, label)
        optimizer.step()

        top1.update(100 * acc1.numpy()[0], n)
        top5.update(100 * acc5.numpy()[0], n)
        objs.update(loss.numpy()[0], n)
        total_time.update(time.time() - t)
        time_iter = time.time() - t
        t = time.time()

        if step % args.report_freq == 0 and rank == 0:
            logging.info(
                "TRAIN Iter %06d: lr = %f,\tloss = %f,\twc_loss = 1,\tTop-1 err = %f,\tTop-5 err = %f,\tdata_time = %f,\ttrain_time = %f,\tremain_hours=%f",
                step,
                args.learning_rate * decay,
                float(objs.__str__().split()[1]),
                1 - float(top1.__str__().split()[1]) / 100,
                1 - float(top5.__str__().split()[1]) / 100,
                time_data,
                time_iter - time_data,
                time_iter * (args.steps - step) / 3600,
            )
            writers['train'].add_scalar('loss',
                                        float(objs.__str__().split()[1]),
                                        global_step=step)
            writers['train'].add_scalar('top1_err',
                                        1 - float(top1.__str__().split()[1]) / 100,
                                        global_step=step)
            writers['train'].add_scalar('top5_err',
                                        1 - float(top5.__str__().split()[1]) / 100,
                                        global_step=step)
            objs.reset()
            top1.reset()
            top5.reset()
            total_time.reset()

        if step % 10000 == 0 and step != 0:
            loss, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
            logging.info(
                "TEST Iter %06d: loss = %f,\tTop-1 err = %f,\tTop-5 err = %f",
                step, loss, 1 - valid_acc / 100, 1 - valid_acc5 / 100)
            is_best = valid_acc > best_valid_acc
            best_valid_acc = max(valid_acc, best_valid_acc)

            if rank == 0:
                writers['valid'].add_scalar('loss', loss, global_step=step)
                writers['valid'].add_scalar('top1_err', 1 - valid_acc / 100, global_step=step)
                writers['valid'].add_scalar('top5_err', 1 - valid_acc5 / 100, global_step=step)
                logging.info("SAVING %06d", step)
                save_checkpoint(
                    save_dir, {
                        'step': step + 1,
                        'model': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_valid_acc,
                        'optimizer': optimizer.state_dict(),
                    }, is_best)
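# NOTE: `AverageMeter` is not defined in this file. The training loop above reads the
# running average via `float(meter.__str__().split()[1])`, which implies a
# "<name> <avg>" string format. The sketch below is a reconstruction consistent with
# that usage (constructor takes a name; update(val, n), reset(), and __str__);
# the real class may differ in details.
class AverageMeter(object):
    """Tracks a running average; __str__ yields "<name> <avg>"."""

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(1, self.count)

    def __str__(self):
        return "{} {:.6f}".format(self.name, self.avg)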