def __init__(self, args):
    """Configure the evolutionary searcher and restore the trained supernet.

    Loads the experiment's latest supernet checkpoint into a DataParallel
    wrapper and initialises the bookkeeping used by the search loop.
    """
    self.args = args
    # Evolution hyper-parameters straight from the CLI.
    self.max_epochs = args.max_epochs
    self.select_num = args.select_num
    self.population_num = args.population_num
    self.m_prob = args.m_prob
    self.crossover_num = args.crossover_num
    self.mutation_num = args.mutation_num
    self.flops_limit = args.flops_limit
    self.exp_name = args.exp_name
    # with open('sn_custom_nets_01_31.pkl', 'rb') as f:  # put in correct file name
    #     self.custom_cands = list(pickle.load(f))
    net = ShuffleNetV2_OneShot(input_size=args.im_size, n_class=args.num_classes)
    self.model = torch.nn.DataParallel(net).cuda()
    restored = torch.load(
        '../Supernet/models/' + self.exp_name + '/checkpoint-latest.pth.tar')['state_dict']
    self.model.load_state_dict(restored)
    self.log_dir = args.log_dir
    self.checkpoint_name = self.log_dir + '/' + self.exp_name + '/checkpoint.pth.tar'
    # Search-state bookkeeping.
    self.memory = []
    self.vis_dict = {}
    self.keep_top_k = {self.select_num: [], 50: []}
    self.epoch = 0
    self.candidates = []
    # 20 choice blocks, 4 candidate ops per block.
    self.nr_layer = 20
    self.nr_state = 4
def test_supernet():
    """Run a dummy forward pass through the supernet (network.py) and save its params."""
    from network import ShuffleNetV2_OneShot, get_channel_mask
    stage_repeats = [4, 8, 4, 4]
    stage_out_channels = [64, 160, 320, 640]
    candidate_scales = [0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0]
    # One op choice per choice block, cast to the float32 NDArray the net expects.
    block_choice = [0, 0, 3, 1, 1, 1, 0, 0, 2, 0, 2, 1, 1, 0, 2, 0, 2, 1, 3, 2]
    architecture = mxnet.nd.array(block_choice).astype(dtype='float32', copy=False)
    channel_choice = (4, ) * 20
    channel_mask = get_channel_mask(channel_choice, stage_repeats, stage_out_channels,
                                    candidate_scales, dtype='float32')
    net = ShuffleNetV2_OneShot()
    print(net)
    net.hybridize()
    net._initialize(ctx=mxnet.cpu())
    dummy_input = mxnet.nd.random.uniform(-1, 1, shape=(5, 3, 224, 224))
    outputs = net(dummy_input, architecture, channel_mask)
    print(outputs.shape)
    net.collect_params().save('supernet.params')
def __init__(self, args):
    """Set up the evolutionary search state and restore the supernet weights."""
    self.args = args
    # Search hyper-parameters taken straight from the CLI.
    self.max_epochs = args.max_epochs
    self.select_num = args.select_num
    self.population_num = args.population_num
    self.m_prob = args.m_prob
    self.crossover_num = args.crossover_num
    self.mutation_num = args.mutation_num
    self.flops_limit = args.flops_limit
    # Supernet wrapped in DataParallel, weights from the latest checkpoint.
    net = ShuffleNetV2_OneShot()
    self.model = torch.nn.DataParallel(net).cuda()
    restored = torch.load(
        '../Supernet/models/checkpoint-latest.pth.tar')['state_dict']
    self.model.load_state_dict(restored)
    self.log_dir = args.log_dir
    self.checkpoint_name = os.path.join(self.log_dir, 'checkpoint.pth.tar')
    # Bookkeeping for the search loop.
    self.memory = []
    self.vis_dict = {}
    self.keep_top_k = {self.select_num: [], 50: []}
    self.epoch = 0
    self.candidates = []
    # 20 choice blocks, 4 candidate ops per block.
    self.nr_layer = 20
    self.nr_state = 4
def main():
    """Score one split of a precomputed candidate-architecture list on the supernet.

    Loads the latest supernet checkpoint for ``args.exp_name``, evaluates every
    candidate in the split assigned to ``args.gpu`` with ``get_cand_err``, and
    pickles the parallel error/candidate lists under ``./data``.
    """
    args = get_args()
    assert args.exp_name is not None
    # Partition the candidate list into 5 contiguous splits of 800 candidates,
    # one per worker GPU (equivalent to the original start/end accumulation loop).
    k = 800
    splits = [(i * k, (i + 1) * k) for i in range(5)]
    print(splits)
    # NOTE(review): hard-coded absolute path — parameterize before reuse.
    # Context manager closes the handle (the original pickle.load(open(...)) leaked it).
    with open("/home/bg141/SinglePathOneShot/src/data/loc_data_ed_15.p", "rb") as f:
        candidate_list = pickle.load(f)
    model = ShuffleNetV2_OneShot(input_size=args.im_size, n_class=args.num_classes)
    model = nn.DataParallel(model)
    model = model.cuda()
    lastest_model, iters = get_lastest_model(args.exp_name)
    print("Iters: ", iters)
    if lastest_model is not None:
        checkpoint = torch.load(lastest_model)
        model.load_state_dict(checkpoint['state_dict'], strict=True)
        print('load from checkpoint')
    err_list = []
    cand_list = []
    print("Split: ", splits[args.gpu])
    start, end = splits[args.gpu]
    # enumerate replaces the original manual counter (same 1-based progress prints).
    for i, cand in enumerate(candidate_list[start:end], start=1):
        err_list.append(get_cand_err(model, cand, args))
        cand_list.append(cand)
        print("Net: ", i)
    # Persist results; with-blocks close the files even if pickling fails
    # (the original pickle.dump(..., open(...)) leaked the handles).
    with open("./data/err-" + args.exp_name + "-" + str(args.gpu) + "_ed_15.p", "wb") as f:
        pickle.dump(err_list, f)
    with open("./data/cand-" + args.exp_name + "-" + str(args.gpu) + "_ed_15.p", "wb") as f:
        pickle.dump(cand_list, f)
    print("Finished")
    return
def test_load_supernet_params():
    """Check that saved supernet parameters can be restored in search mode."""
    from network import ShuffleNetV2_OneShot
    import mxnet
    net = ShuffleNetV2_OneShot(search=True)
    params = net.collect_params()
    params.load('supernet.params', ctx=mxnet.cpu(),
                cast_dtype=True, dtype_source='saved')
    print("Done!")
def main3():
    """Sample 5000 FLOPs-bucketed random architectures and record their supernet error.

    Each candidate is drawn uniformly from a random FLOPs bin, scored with
    ``get_cand_err``, and results are pickled (with periodic snapshots) under
    ``./data/<exp_name>/``.
    """
    args = get_args()
    assert args.exp_name is not None
    # makedirs with exist_ok avoids the original check-then-mkdir race and
    # also creates any missing parent directories.
    os.makedirs('./data/' + args.exp_name, exist_ok=True)

    # Candidate sampler: 20 choice blocks, 4 candidate ops each.
    get_random_cand = lambda: tuple(np.random.randint(4) for i in range(20))
    flops_l, flops_r, flops_step = 290, 360, 50
    bins = [[i, i + flops_step] for i in range(flops_l, flops_r, flops_step)]

    def get_uniform_sample_cand(*, timeout=500):
        """Rejection-sample a candidate whose FLOPs fall in a randomly chosen bin."""
        idx = np.random.randint(len(bins))
        l, r = bins[idx]
        for i in range(timeout):
            cand = get_random_cand()
            if l * 1e6 <= get_cand_flops(cand) <= r * 1e6:
                return cand
        print("timeout")
        return get_random_cand()

    model = ShuffleNetV2_OneShot(input_size=args.im_size, n_class=args.num_classes)
    model = nn.DataParallel(model)
    model = model.cuda()
    lastest_model, iters = get_lastest_model(args.exp_name)
    if lastest_model is not None:
        checkpoint = torch.load(lastest_model)
        model.load_state_dict(checkpoint['state_dict'], strict=True)
        print('load from checkpoint')
    err_list = []
    cand_list = []
    print("GPU: ", args.gpu)
    for i in range(5000):
        cand = get_uniform_sample_cand()
        err_list.append(get_cand_err(model, cand, args))
        cand_list.append(cand)
        print("Net: ", i)
        if i % 500 == 0:
            # Periodic snapshot so a crash doesn't lose everything; the
            # with-blocks close the files (the original leaked the handles).
            with open("./data/" + args.exp_name + "/err-" + str(args.gpu) + "-" + str(i) + ".p", "wb") as f:
                pickle.dump(err_list, f)
            with open("./data/" + args.exp_name + "/cand-" + str(args.gpu) + "-" + str(i) + ".p", "wb") as f:
                pickle.dump(cand_list, f)
    with open("./data/" + args.exp_name + "/err-" + str(args.gpu) + ".p", "wb") as f:
        pickle.dump(err_list, f)
    with open("./data/" + args.exp_name + "/cand-" + str(args.gpu) + ".p", "wb") as f:
        pickle.dump(cand_list, f)
    print("Finished")
    return
def __init__(self, args):
    """Seed RNGs, build the data pipeline, and restore the supernet (MXNet searcher)."""
    self.args = args
    # One context per GPU id in args.gpus; CPU fallback when the list is empty.
    self.context = [mx.gpu(int(gpu)) for gpu in args.gpus.split(',')] if len(args.gpus.split(',')) > 0 else [mx.cpu()]
    # Seed every device RNG plus numpy/random for reproducible sampling.
    for ctx in self.context:
        mx.random.seed(self.args.random_seed, ctx=ctx)
    np.random.seed(self.args.random_seed)
    random.seed(self.args.random_seed)
    num_gpus = len(self.args.gpus.split(','))
    # Global batch size scales with the number of devices.
    batch_size = max(1, num_gpus) * self.args.batch_size
    if self.args.use_rec:
        if self.args.use_dali:
            # NOTE(review): `opt` is not defined in this scope — `opt.dtype` will
            # raise NameError unless `opt` exists at module level; confirm.
            self.train_data = dali.get_data_rec((3, self.args.input_size, self.args.input_size),
                                                self.args.crop_ratio, self.args.rec_train,
                                                self.args.rec_train_idx, self.args.batch_size,
                                                num_workers=2, train=True, shuffle=True,
                                                backend='dali-gpu', gpu_ids=[0,1],
                                                kv_store='nccl', dtype=opt.dtype,
                                                input_layout='NCHW')
            self.val_data = dali.get_data_rec((3, self.args.input_size, self.args.input_size),
                                              self.args.crop_ratio, self.args.rec_val,
                                              self.args.rec_val_idx, self.args.batch_size,
                                              num_workers=2, train=False, shuffle=False,
                                              backend='dali-gpu', gpu_ids=[0,1],
                                              kv_store='nccl', dtype=opt.dtype,
                                              input_layout='NCHW')
            # NOTE(review): `batch_fn` is also undefined here — likely a module-level
            # helper or a copy-paste remnant; verify before enabling the DALI path.
            self.batch_fn = batch_fn
        else:
            self.train_data, self.val_data, self.batch_fn = get_data_rec(self.args.rec_train,
                                                                         self.args.rec_train_idx,
                                                                         self.args.rec_val,
                                                                         self.args.rec_val_idx,
                                                                         batch_size,
                                                                         self.args.num_workers,
                                                                         self.args.random_seed)
    else:
        # Plain image-folder loader when no RecordIO files are used.
        self.train_data, self.val_data, self.batch_fn = get_data_loader(self.args.data_dir,
                                                                        batch_size,
                                                                        self.args.num_workers)
    # Supernet in search mode; weights restored from args.resume_params.
    self.model = ShuffleNetV2_OneShot(search=True)
    self.model.collect_params().load(self.args.resume_params, ctx=self.context,
                                     cast_dtype=True, dtype_source='saved')
    # Search-state bookkeeping.
    self.memory = []
    self.vis_dict = {}
    self.keep_top_k = {self.args.select_num: [], 50: []}
    self.epoch = 0
    self.candidates = []
    # 20 choice blocks, 4 op choices per block, 10 channel-scale choices.
    self.nr_layer = 20
    self.nr_state = 4
    self.channel_state = 10  # len(candidate_scales)
def test_subnet():
    """Build one fixed subnet (subnet.py) and run a dummy forward pass."""
    from subnet import ShuffleNetV2_OneShot
    block_choice = (0, 0, 3, 1, 1, 1, 0, 0, 2, 0, 2, 1, 1, 0, 2, 0, 2, 1, 3, 2)
    channel_choice = (6, 5, 3, 5, 2, 6, 3, 4, 2, 5, 7, 5, 4, 6, 7, 4, 4, 5, 4, 3)
    # Instantiate one specific architecture (search disabled).
    net = ShuffleNetV2_OneShot(input_size=224,
                               n_class=1000,
                               architecture=block_choice,
                               channels_idx=channel_choice,
                               act_type='relu',
                               search=False)
    net.hybridize()
    net._initialize(ctx=mxnet.cpu())
    print(net)
    dummy_input = mxnet.nd.random.uniform(-1, 1, shape=(5, 3, 224, 224))
    outputs = net(dummy_input)
    print(outputs.shape)
def main():
    """Train the SPOS supernet on ImageNet or CIFAR-10, with optional resume/eval."""
    args = get_args()
    # Log: console + timestamped file under ./log.
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000, local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True
    if args.cifar10 == False:
        # ImageNet branch: BGR tensors to match the original SPOS preprocessing.
        assert os.path.exists(args.train_dir)
        train_dataset = datasets.ImageFolder(
            args.train_dir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                transforms.RandomHorizontalFlip(0.5),
                ToBGRTensor(),
            ]))
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                                   shuffle=True, num_workers=1, pin_memory=use_gpu)
        train_dataprovider = DataIterator(train_loader)
        assert os.path.exists(args.val_dir)
        val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            args.val_dir,
            transforms.Compose([
                OpencvResize(256),
                transforms.CenterCrop(224),
                ToBGRTensor(),
            ])), batch_size=200, shuffle=False, num_workers=1, pin_memory=use_gpu)
        val_dataprovider = DataIterator(val_loader)
        print('load imagenet data successfully')
    else:
        # CIFAR-10 branch.
        train_transform, valid_transform = data_transforms(args)
        trainset = torchvision.datasets.CIFAR10(root=os.path.join(
            args.data_dir, 'cifar'), train=True, download=True, transform=train_transform)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                                   shuffle=True, pin_memory=True, num_workers=8)
        train_dataprovider = DataIterator(train_loader)
        valset = torchvision.datasets.CIFAR10(root=os.path.join(
            args.data_dir, 'cifar'), train=False, download=True, transform=valid_transform)
        val_loader = torch.utils.data.DataLoader(valset, batch_size=args.batch_size,
                                                 shuffle=False, pin_memory=True, num_workers=8)
        val_dataprovider = DataIterator(val_loader)
        print('load cifar10 data successfully')
    model = ShuffleNetV2_OneShot()
    optimizer = torch.optim.SGD(get_parameters(model), lr=args.learning_rate,
                                momentum=args.momentum, weight_decay=args.weight_decay)
    # NOTE(review): label smoothing is hard-coded to 1000 classes even on the
    # CIFAR-10 branch — likely should use the dataset's class count; confirm.
    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)
    if use_gpu:
        model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")
    # Linear LR decay down to 0 over total_iters.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda step: (1.0 - step / args.total_iters) if step <= args.total_iters else 0,
        last_epoch=-1)
    model = model.to(device)
    all_iters = 0
    if args.auto_continue:
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model, map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            print('load from checkpoint')
            # Fast-forward the LR schedule to the resumed iteration.
            for i in range(iters):
                scheduler.step()
    # Stash training collaborators on args so train()/validate() can reach them.
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_dataprovider = train_dataprovider
    args.val_dataprovider = val_dataprovider
    if args.eval:
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume, map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint, strict=True)
            validate(model, device, args, all_iters=all_iters)
        exit(0)
    # Main training loop: train() advances the iteration counter.
    while all_iters < args.total_iters:
        all_iters = train(model, device, args, val_interval=args.val_interval,
                          bn_process=False, all_iters=all_iters)
def main():
    """Train a searched architecture (read from arch.pkl) from scratch and save it."""
    args = get_args()
    # Log: console + timestamped file under ./log.
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(os.path.join('log/train-{}{:02}{}'.format(
        local_time.tm_year % 2000, local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True
    assert os.path.exists(args.train_dir)
    train_dataset = datasets.ImageFolder(
        args.train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(args.im_size),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomHorizontalFlip(0.5),
            ToBGRTensor(),
        ])
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=8, pin_memory=use_gpu)
    train_dataprovider = DataIterator(train_loader)
    assert os.path.exists(args.val_dir)
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(args.val_dir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(args.im_size),
            ToBGRTensor(),
        ])),
        batch_size=200, shuffle=False, num_workers=8, pin_memory=use_gpu
    )
    val_dataprovider = DataIterator(val_loader)
    print('load data successfully')
    # The block-choice sequence found by the search must be present on disk.
    arch_path = 'arch.pkl'
    if os.path.exists(arch_path):
        with open(arch_path, 'rb') as f:
            architecture = pickle.load(f)
    else:
        raise NotImplementedError
    # Full-width channels for every one of the 20 blocks.
    channels_scales = (1.0,) * 20
    model = ShuffleNetV2_OneShot(architecture=architecture, channels_scales=channels_scales,
                                 n_class=args.num_classes, input_size=args.im_size)
    print('flops:', get_flops(model))
    optimizer = torch.optim.SGD(get_parameters(model), lr=args.learning_rate,
                                momentum=args.momentum, weight_decay=args.weight_decay)
    criterion_smooth = CrossEntropyLabelSmooth(args.num_classes, 0.1)
    if use_gpu:
        # model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")
    # Linear LR decay down to 0 over total_iters.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda step: (1.0 - step / args.total_iters) if step <= args.total_iters else 0,
        last_epoch=-1)
    # model = model.to(device)
    model = model.cuda()
    all_iters = 0
    if args.auto_continue:
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model, map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            print('load from checkpoint')
            # Fast-forward the LR schedule to the resumed iteration.
            for i in range(iters):
                scheduler.step()
    # Stash training collaborators on args so train()/validate() can reach them.
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_dataprovider = train_dataprovider
    args.val_dataprovider = val_dataprovider
    if args.eval:
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume, map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint, strict=True)
            validate(model, device, args, all_iters=all_iters)
        exit(0)
    t = time.time()
    # Train/validate until the iteration budget is exhausted.
    while all_iters < args.total_iters:
        all_iters = train(model, device, args, val_interval=args.val_interval,
                          bn_process=False, all_iters=all_iters)
        validate(model, device, args, all_iters=all_iters)
    # all_iters = train(model, device, args, val_interval=int(1280000/args.batch_size), bn_process=True, all_iters=all_iters)
    validate(model, device, args, all_iters=all_iters)
    save_checkpoint({'state_dict': model.state_dict(),}, args.total_iters, tag='bnps-')
    print("Finished {} iters in {:.3f} seconds".format(all_iters, time.time()-t))
def main():
    """Distributed single-path architecture search over ImageNet.

    Merges a YAML config into CLI args, builds the supernet with learnable
    op logits (``log_alpha``), trains weights with SGD and the architecture
    distribution with Adam, and checkpoints the best top-1 model on rank 0.
    """
    global args, best_prec1
    args = parser.parse_args()
    with open(args.config) as f:
        # NOTE(review): yaml.load without an explicit Loader is deprecated and
        # unsafe on untrusted input — prefer yaml.safe_load / Loader=SafeLoader.
        config = yaml.load(f)
    # Flatten the two-level YAML config onto args.
    for key in config:
        for k, v in config[key].items():
            setattr(args, k, v)
    print('Enabled distributed training.')
    rank, world_size = init_dist(
        backend='nccl', port=args.port)
    args.rank = rank
    args.world_size = world_size
    # Per-rank seed so workers sample different paths (rank 0 gets seed 0).
    np.random.seed(args.seed*args.rank)
    torch.manual_seed(args.seed*args.rank)
    torch.cuda.manual_seed(args.seed*args.rank)
    torch.cuda.manual_seed_all(args.seed*args.rank)
    print('random seed: ', args.seed*args.rank)
    # create model
    print("=> creating model '{}'".format(args.model))
    if args.SinglePath:
        architecture = 20*[0]
        channels_scales = 20*[1.0]
        model = ShuffleNetV2_OneShot(args=args, architecture=architecture,
                                     channels_scales=channels_scales)
        model.cuda()
        broadcast_params(model)
        # Pre-allocate zero grads so distributed grad reduction sees every tensor.
        for v in model.parameters():
            if v.requires_grad:
                if v.grad is None:
                    v.grad = torch.zeros_like(v)
        model.log_alpha.grad = torch.zeros_like(model.log_alpha)
    criterion = CrossEntropyLoss(smooth_eps=0.1,
                                 smooth_dist=(torch.ones(1000)*0.001).cuda()).cuda()
    # Split parameters: BatchNorm params get no weight decay; log_alpha is
    # excluded entirely (it is optimized by the separate arch optimizer).
    wo_wd_params = []
    wo_wd_param_names = []
    network_params = []
    network_param_names = []
    for name, mod in model.named_modules():
        if isinstance(mod, nn.BatchNorm2d):
            for key, value in mod.named_parameters():
                wo_wd_param_names.append(name+'.'+key)
    for key, value in model.named_parameters():
        if key != 'log_alpha':
            if value.requires_grad:
                if key in wo_wd_param_names:
                    wo_wd_params.append(value)
                else:
                    network_params.append(value)
                    network_param_names.append(key)
    params = [
        {'params': network_params, 'lr': args.base_lr, 'weight_decay': args.weight_decay},
        {'params': wo_wd_params, 'lr': args.base_lr, 'weight_decay': 0.},
    ]
    param_names = [network_param_names, wo_wd_param_names]
    if args.rank == 0:
        print('>>> params w/o weight decay: ', wo_wd_param_names)
    optimizer = torch.optim.SGD(params, momentum=args.momentum)
    if args.SinglePath:
        # Adam on the architecture logits only.
        arch_optimizer = torch.optim.Adam(
            [param for name, param in model.named_parameters() if name == 'log_alpha'],
            lr=args.arch_learning_rate, betas=(0.5, 0.999),
            weight_decay=args.arch_weight_decay
        )
    # auto resume from a checkpoint
    # Build a descriptive run name from the hyper-parameters.
    remark = 'imagenet_'
    remark += 'epo_' + str(args.epochs) + '_layer_' + str(args.layers) + '_batch_' + str(args.batch_size) + '_lr_' + str(args.base_lr) + '_seed_' + str(args.seed) + '_pretrain_' + str(args.pretrain_epoch)
    if args.early_fix_arch:
        remark += '_early_fix_arch'
    if args.flops_loss:
        remark += '_flops_loss_' + str(args.flops_loss_coef)
    if args.remark != 'none':
        remark += '_'+args.remark
    args.save = 'search-{}-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"), remark)
    args.save_log = 'nas-{}-{}'.format(time.strftime("%Y%m%d-%H%M%S"), remark)
    generate_date = str(datetime.now().date())
    path = os.path.join(generate_date, args.save)
    if args.rank == 0:
        # Rank 0 owns logging, experiment-dir creation, and TensorBoard.
        log_format = '%(asctime)s %(message)s'
        utils.create_exp_dir(generate_date, path, scripts_to_save=glob.glob('*.py'))
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(path, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        logging.info("args = %s", args)
        writer = SummaryWriter('./runs/' + generate_date + '/' + args.save_log)
    else:
        writer = None
    model_dir = path
    start_epoch = 0
    if args.evaluate:
        load_state_ckpt(args.checkpoint_path, model)
    else:
        best_prec1, start_epoch = load_state(model_dir, model, optimizer=optimizer)
    cudnn.benchmark = True
    cudnn.enabled = True
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = ImagenetDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    # Same data without multi-scale cropping, used near the end of training.
    train_dataset_wo_ms = ImagenetDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    val_dataset = ImagenetDataset(
        args.val_root,
        args.val_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size//args.world_size, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=train_sampler)
    train_loader_wo_ms = DataLoader(
        train_dataset_wo_ms, batch_size=args.batch_size//args.world_size, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=train_sampler)
    val_loader = DataLoader(
        val_dataset, batch_size=50, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=val_sampler)
    if args.evaluate:
        validate(val_loader, model, criterion, 0, writer, logging)
        return
    niters = len(train_loader)
    lr_scheduler = LRScheduler(optimizer, niters, args)
    # NOTE(review): epoch count is hard-coded to 85 rather than args.epochs — confirm intended.
    for epoch in range(start_epoch, 85):
        train_sampler.set_epoch(epoch)
        if args.early_fix_arch:
            # Freeze any block whose top-2 op probabilities differ by >= 0.3,
            # pinning its logits from then on.
            if len(model.fix_arch_index.keys()) > 0:
                for key, value_lst in model.fix_arch_index.items():
                    model.log_alpha.data[key, :] = value_lst[1]
            sort_log_alpha = torch.topk(F.softmax(model.log_alpha.data, dim=-1), 2)
            argmax_index = (sort_log_alpha[0][:,0] - sort_log_alpha[0][:,1] >= 0.3)
            for id in range(argmax_index.size(0)):
                if argmax_index[id] == 1 and id not in model.fix_arch_index.keys():
                    model.fix_arch_index[id] = [sort_log_alpha[1][id,0].item(),
                                                model.log_alpha.detach().clone()[id, :]]
        if args.rank == 0 and args.SinglePath:
            logging.info('epoch %d', epoch)
            logging.info(model.log_alpha)
            logging.info(F.softmax(model.log_alpha, dim=-1))
            logging.info('flops %fM', model.cal_flops())
        # train for one epoch
        if epoch >= args.epochs - 5 and args.lr_mode == 'step' and args.off_ms:
            train(train_loader_wo_ms, model, criterion, optimizer,
                  arch_optimizer, lr_scheduler, epoch, writer, logging)
        else:
            train(train_loader, model, criterion, optimizer,
                  arch_optimizer, lr_scheduler, epoch, writer, logging)
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, writer, logging)
        if args.gen_max_child:
            # Also evaluate the argmax (most likely) child architecture.
            args.gen_max_child_flag = True
            prec1 = validate(val_loader, model, criterion, epoch, writer, logging)
            args.gen_max_child_flag = False
        if rank == 0:
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(model_dir, {
                'epoch': epoch + 1,
                'model': args.model,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
def main():
    """Train a 10-architecture slice of cl400.p from scratch (96px, 10 classes) and record top-1/top-5."""
    args = get_args()
    # Log: console + timestamped file under ./log.
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000, local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True
    assert os.path.exists(args.train_dir)
    train_dataset = datasets.ImageFolder(
        args.train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(96),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomHorizontalFlip(0.5),
            ToBGRTensor(),
        ]))
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                               shuffle=True, num_workers=4, pin_memory=use_gpu)
    train_dataprovider = DataIterator(train_loader)
    assert os.path.exists(args.val_dir)
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(
            args.val_dir,
            transforms.Compose([
                OpencvResize(96),
                # transforms.CenterCrop(96),
                ToBGRTensor(),
            ])),
        batch_size=200, shuffle=False, num_workers=4, pin_memory=use_gpu)
    val_dataprovider = DataIterator(val_loader)
    # 400 candidate architectures, evaluated 10 at a time per split_num.
    arch_path = 'cl400.p'
    if os.path.exists(arch_path):
        with open(arch_path, 'rb') as f:
            architectures = pickle.load(f)
    else:
        raise NotImplementedError
    # Full-width channels for all 20 blocks.
    channels_scales = (1.0, ) * 20
    cands = {}
    # 40 splits of 10 consecutive architectures each.
    splits = [(i, 10 + i) for i in range(0, 400, 10)]
    architectures = np.array(architectures)
    architectures = architectures[
        splits[args.split_num][0]:splits[args.split_num][1]]
    print(len(architectures))
    logging.info("Training and Validating arch: " + str(splits[args.split_num]))
    for architecture in architectures:
        architecture = tuple(architecture.tolist())
        model = ShuffleNetV2_OneShot(architecture=architecture,
                                     channels_scales=channels_scales,
                                     n_class=10, input_size=96)
        print('flops:', get_flops(model))
        optimizer = torch.optim.SGD(get_parameters(model), lr=args.learning_rate,
                                    momentum=args.momentum, weight_decay=args.weight_decay)
        # NOTE(review): label smoothing hard-codes 1000 classes while the model
        # is built with n_class=10 — likely should be 10; confirm.
        criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)
        if use_gpu:
            model = nn.DataParallel(model)
            loss_function = criterion_smooth.cuda()
            device = torch.device("cuda")
        else:
            loss_function = criterion_smooth
            device = torch.device("cpu")
        # Linear LR decay down to 0 over total_iters.
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lambda step: (1.0 - step / args.total_iters) if step <= args.total_iters else 0,
            last_epoch=-1)
        model = model.to(device)
        all_iters = 0
        if args.auto_continue:
            lastest_model, iters = get_lastest_model()
            if lastest_model is not None:
                all_iters = iters
                checkpoint = torch.load(
                    lastest_model, map_location=None if use_gpu else 'cpu')
                model.load_state_dict(checkpoint['state_dict'], strict=True)
                print('load from checkpoint')
                # Fast-forward the LR schedule to the resumed iteration.
                for i in range(iters):
                    scheduler.step()
        # Stash training collaborators on args so train()/validate() can reach them.
        args.optimizer = optimizer
        args.loss_function = loss_function
        args.scheduler = scheduler
        args.train_dataprovider = train_dataprovider
        args.val_dataprovider = val_dataprovider
        # print("BEGIN VALDATE: ", args.eval, args.eval_resume)
        if args.eval:
            if args.eval_resume is not None:
                checkpoint = torch.load(
                    args.eval_resume, map_location=None if use_gpu else 'cpu')
                model.load_state_dict(checkpoint, strict=True)
                validate(model, device, args, all_iters=all_iters)
            exit(0)
        # t1,t5 = validate(model, device, args, all_iters=all_iters)
        # print("VALDATE: ", t1, " ", t5)
        while all_iters < args.total_iters:
            all_iters = train(model, device, args, val_interval=args.val_interval,
                              bn_process=False, all_iters=all_iters)
            validate(model, device, args, all_iters=all_iters)
        # Final pass with BN recalibration before the reported validation.
        all_iters = train(model, device, args,
                          val_interval=int(1280000 / args.batch_size),
                          bn_process=True, all_iters=all_iters)
        top1, top5 = validate(model, device, args, all_iters=all_iters)
        save_checkpoint({
            'state_dict': model.state_dict(),
        }, args.total_iters, tag='bnps-')
        # Record and persist results after every architecture (overwrites the file).
        cands[architecture] = [top1, top5]
        pickle.dump(
            cands,
            open("from_scratch_split_{}.pkl".format(args.split_num), 'wb'))
def train(model, meta_model, device, args, *, val_interval, bn_process=False, all_iters=None):
    """Alternating meta-training loop: 4 sub-model steps, then one sync step.

    ``meta_model`` is trained on a sampled candidate path; every 5th iteration
    its accumulated gradients are copied to ``model``, ``model`` takes an
    optimizer step, and its weights are copied back into ``meta_model``.
    Returns the updated global iteration counter.
    """
    optimizer = args.optimizer
    meta_optimizer = args.meta_optimizer
    loss_function = args.loss_function
    scheduler = args.scheduler
    meta_scheduler = args.meta_scheduler
    train_dataprovider = args.train_dataprovider
    t1 = time.time()
    Top1_err, Top5_err = 0.0, 0.0
    model.train()
    meta_model.train()
    for iters in range(1, val_interval + 1):
        scheduler.step()
        meta_scheduler.step()
        if bn_process:
            adjust_bn_momentum(model, iters)
            adjust_bn_momentum(meta_model, iters)
        all_iters += 1
        d_st = time.time()
        data, target = train_dataprovider.next()
        target = target.type(torch.LongTensor)
        data, target = data.to(device), target.to(device)
        data_time = time.time() - d_st
        # Candidate sampler: 20 choice blocks, 4 ops each, FLOPs-bucketed.
        # (Rebuilt every iteration — could be hoisted out of the loop.)
        get_random_cand = lambda: tuple(
            np.random.randint(4) for i in range(20))
        flops_l, flops_r, flops_step = 290, 360, 10
        bins = [[i, i + flops_step] for i in range(flops_l, flops_r, flops_step)]

        def get_uniform_sample_cand(*, timeout=500):
            # Rejection-sample a candidate whose FLOPs fall in a random bin.
            idx = np.random.randint(len(bins))
            l, r = bins[idx]
            for i in range(timeout):
                cand = get_random_cand()
                if l * 1e6 <= get_cand_flops(cand) <= r * 1e6:
                    return cand
            return get_random_cand()

        # Resample the path at the start of each 5-iteration cycle; iters
        # starts at 1, so `cand` is always bound before first use.
        if iters % 5 == 1:
            cand = get_uniform_sample_cand()
        output = meta_model(data, cand)
        loss = loss_function(output, target)
        optimizer.zero_grad()
        meta_optimizer.zero_grad()
        loss.backward()
        # Drop all-zero gradients so untouched branches don't get updated.
        for p in meta_model.parameters():
            if p.grad is not None and p.grad.sum() == 0:
                p.grad = None
        if iters % 5 != 0:
            # step 1: update submodel
            meta_optimizer.step()
        else:
            # step 2: update original model
            # # copy gradient to original model
            for p, q in zip(model.parameters(), meta_model.parameters()):
                if q.grad is not None:
                    p.grad = q.grad.clone()
            # # check
            for p, q in zip(model.parameters(), meta_model.parameters()):
                if q.grad is not None:
                    assert torch.all(torch.eq(p.grad, q.grad))
            # # update weight
            optimizer.step()
            # load weight to submodel
            meta_model.load_state_dict(model.state_dict())
            # check
            for p, q in zip(model.parameters(), meta_model.parameters()):
                if p is not None:
                    assert torch.all(torch.eq(q, p))
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        Top1_err += 1 - prec1.item() / 100
        Top5_err += 1 - prec5.item() / 100
        if all_iters % args.display_interval == 0:
            printInfo = 'TRAIN Iter {}: lr = {:.6f},\tloss = {:.6f},\t'.format(all_iters, scheduler.get_lr()[0], loss.item()) + \
                        'Top-1 err = {:.6f},\t'.format(Top1_err / args.display_interval) + \
                        'Top-5 err = {:.6f},\t'.format(Top5_err / args.display_interval) + \
                        'data_time = {:.6f},\ttrain_time = {:.6f}'.format(data_time, (time.time() - t1) / args.display_interval)
            logging.info(printInfo)
            t1 = time.time()
            Top1_err, Top5_err = 0.0, 0.0
        if all_iters % args.save_interval == 0 or all_iters == 1:
            save_checkpoint({
                'state_dict': model.state_dict(),
            }, all_iters)
            # Validate a fresh copy loaded from args.eval_resume (not the
            # in-training model) — presumably to test checkpoint round-trip;
            # confirm args.eval_resume points at the intended file.
            checkpoint = torch.load(args.eval_resume)
            val_model = ShuffleNetV2_OneShot()
            val_model = nn.DataParallel(val_model)
            val_model = val_model.to(device)
            val_model.load_state_dict(checkpoint['state_dict'], strict=True)
            validate(val_model, device, args, all_iters=all_iters)
    return all_iters
def main():
    """Train a ShuffleNetV2 one-shot supernet on ImageNet with MXNet/Gluon.

    Sets up logging, seeds, data pipelines (RecordIO, DALI, or raw folders),
    LR schedule and trainer, then runs the epoch loop in the nested
    ``train(ctx)``. Each epoch samples random block choices per batch and a
    per-epoch channel mask whose allowed range widens as training progresses
    (progressive channel-shrinking warm-up).

    Fixes applied in review:
      * ``test()``'s BN-recalibration loop called ``batch_fn(train_data)``;
        ``batch_fn`` takes ``(batch, ctx)``, so it now receives the current
        batch and the context list.
      * ``is``/``is not`` comparisons against ``''`` replaced with
        ``==``/``!=`` (identity comparison with a literal is incorrect and a
        SyntaxWarning since Python 3.8).
    """
    opt = parse_args()

    makedirs(opt.log_dir)
    filehandler = logging.FileHandler(opt.log_dir + '/' + opt.logging_file)
    streamhandler = logging.StreamHandler()
    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)
    logger.info(opt)

    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167  # ImageNet-1k train-set size

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)  # global batch size across devices
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    # Seed every device plus numpy/python RNGs for reproducibility.
    for ctx in context:
        mx.random.seed(seed_state=opt.random_seed, ctx=ctx)
    np.random.seed(opt.random_seed)
    random.seed(opt.random_seed)
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    # Decay epochs are expressed relative to the end of warm-up.
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
    num_batches = num_training_samples // batch_size

    # Linear warm-up followed by the configured decay mode.
    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=opt.lr,
                    nepochs=opt.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode,
                    base_lr=opt.lr,
                    target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay,
                    power=2)
    ])

    sw = SummaryWriter(logdir=opt.log_dir, flush_secs=5, verbose=False)

    optimizer = 'sgd'
    optimizer_params = {
        'wd': opt.wd,
        'momentum': opt.momentum,
        'lr_scheduler': lr_scheduler
    }
    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    net = ShuffleNetV2_OneShot()
    net.cast(opt.dtype)
    if opt.mode == 'hybrid':
        net.hybridize()
    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx=context)

    # Two functions for reading data from record file or raw images
    def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx,
                     batch_size, num_workers, seed):
        rec_train = os.path.expanduser(rec_train)
        rec_train_idx = os.path.expanduser(rec_train_idx)
        rec_val = os.path.expanduser(rec_val)
        rec_val_idx = os.path.expanduser(rec_val_idx)
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))
        mean_rgb = [123.68, 116.779, 103.939]
        std_rgb = [58.393, 57.12, 57.375]

        def batch_fn(batch, ctx):
            # Split a DataBatch across the device list.
            data = gluon.utils.split_and_load(batch.data[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0],
                                               ctx_list=ctx,
                                               batch_axis=0)
            return data, label

        train_data = mx.io.ImageRecordIter(
            path_imgrec=rec_train,
            path_imgidx=rec_train_idx,
            preprocess_threads=num_workers,
            shuffle=True,
            batch_size=batch_size,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
            rand_mirror=True,
            random_resized_crop=True,
            max_aspect_ratio=4. / 3.,
            min_aspect_ratio=3. / 4.,
            max_random_area=1,
            min_random_area=0.08,
            brightness=jitter_param,
            saturation=jitter_param,
            contrast=jitter_param,
            pca_noise=lighting_param,
            shuffle_chunk_seed=seed,
            seed=seed,
            seed_aug=seed,
        )
        val_data = mx.io.ImageRecordIter(
            path_imgrec=rec_val,
            path_imgidx=rec_val_idx,
            preprocess_threads=num_workers,
            shuffle=False,
            batch_size=batch_size,
            resize=resize,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
        )
        return train_data, val_data, batch_fn

    def get_data_loader(data_dir, batch_size, num_workers):
        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            return data, label

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=jitter_param,
                                         contrast=jitter_param,
                                         saturation=jitter_param),
            transforms.RandomLighting(lighting_param),
            transforms.ToTensor(), normalize
        ])
        transform_test = transforms.Compose([
            transforms.Resize(resize, keep_ratio=True),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(), normalize
        ])

        train_data = gluon.data.DataLoader(imagenet.classification.ImageNet(
            data_dir, train=True).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)
        val_data = gluon.data.DataLoader(imagenet.classification.ImageNet(
            data_dir, train=False).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)
        return train_data, val_data, batch_fn

    if opt.use_rec:
        if opt.use_dali:
            train_data = dali.get_data_rec((3, opt.input_size, opt.input_size),
                                           opt.crop_ratio,
                                           opt.rec_train,
                                           opt.rec_train_idx,
                                           opt.batch_size,
                                           num_workers=2,
                                           train=True,
                                           shuffle=True,
                                           backend='dali-gpu',
                                           gpu_ids=[0, 1],
                                           kv_store='nccl',
                                           dtype=opt.dtype,
                                           input_layout='NCHW')
            val_data = dali.get_data_rec((3, opt.input_size, opt.input_size),
                                         opt.crop_ratio,
                                         opt.rec_val,
                                         opt.rec_val_idx,
                                         opt.batch_size,
                                         num_workers=2,
                                         train=False,
                                         shuffle=False,
                                         backend='dali-gpu',
                                         gpu_ids=[0, 1],
                                         kv_store='nccl',
                                         dtype=opt.dtype,
                                         input_layout='NCHW')

            def batch_fn(batch, ctx):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)
                return data, label
        else:
            train_data, val_data, batch_fn = get_data_rec(
                opt.rec_train, opt.rec_train_idx, opt.rec_val,
                opt.rec_val_idx, batch_size, num_workers, opt.random_seed)
    else:
        train_data, val_data, batch_fn = get_data_loader(
            opt.data_dir, batch_size, num_workers)

    # With mixup the targets are soft, so RMSE stands in for accuracy.
    if opt.mixup:
        train_metric = mx.metric.RMSE()
    else:
        train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

    def mixup_transform(label, classes, lam=1, eta=0.0):
        # Blend each label with the reversed batch's labels (mixup), with
        # optional label smoothing (eta).
        if isinstance(label, nd.NDArray):
            label = [label]
        res = []
        for l in label:
            y1 = l.one_hot(classes,
                           on_value=1 - eta + eta / classes,
                           off_value=eta / classes)
            y2 = l[::-1].one_hot(classes,
                                 on_value=1 - eta + eta / classes,
                                 off_value=eta / classes)
            res.append(lam * y1 + (1 - lam) * y2)
        return res

    def smooth(label, classes, eta=0.1):
        # Standard label smoothing: on-value 1 - eta + eta/C, off-value eta/C.
        if isinstance(label, nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            res = l.one_hot(classes,
                            on_value=1 - eta + eta / classes,
                            off_value=eta / classes)
            smoothed.append(res)
        return smoothed

    def test(net,
             batch_fn,
             ctx,
             train_data,
             val_data,
             cand,
             channel_mask,
             update_images=20000,
             update_bn=False):
        """Evaluate `cand`/`channel_mask` on val_data; optionally re-estimate
        BN statistics on ~`update_images` training images first."""
        if update_bn:
            if opt.use_rec:
                train_data.reset()
            net.cast('float32')
            # NOTE(review): this only scans direct children for BatchNormNAS;
            # confirm the NAS BN layers are registered at the top level.
            for k, v in net._children.items():
                if isinstance(v, BatchNormNAS):
                    v.inference_update_stat = True
            for i, batch in enumerate(train_data):
                if (i + 1) * opt.batch_size * len(ctx) >= update_images:
                    break
                # FIX: was `batch_fn(train_data)` — batch_fn takes (batch, ctx).
                data, _ = batch_fn(batch, ctx)
                _ = [
                    net(
                        X.astype('float32', copy=False),
                        cand.as_in_context(X.context).astype('float32',
                                                             copy=False),
                        channel_mask.as_in_context(X.context).astype(
                            'float32', copy=False)) for X in data
                ]
            for k, v in net._children.items():
                if isinstance(v, BatchNormNAS):
                    v.inference_update_stat = False
            net.cast(opt.dtype)
        if opt.use_rec:
            val_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            outputs = [
                net(X.astype(opt.dtype, copy=False),
                    cand.as_in_context(X.context),
                    channel_mask.as_in_context(X.context)) for X in data
            ]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)
        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return (top1, top5)

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        if opt.resume_params == '':
            net._initialize(ctx=ctx, force_reinit=True)

        if opt.no_wd:
            # Exclude BN/bias parameters from weight decay.
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                optimizer_params)
        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True
        L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)

        best_val_score = 0
        iteration = 0
        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            train_metric.reset()
            btic = time.time()

            get_random_cand = lambda x: tuple(
                np.random.randint(x) for i in range(20))

            # Get random channel mask: progressively widen the channel-choice
            # range as training advances (progressive shrinking warm-up).
            if epoch < 10:
                channel = (9, ) * 20  # Firstly, train full supernet
            elif epoch < 15:
                channel = tuple(np.random.randint(2) + 8
                                for i in range(20))  # the channel choice is 8 ~ 9
            elif epoch < 20:
                channel = tuple(np.random.randint(3) + 7
                                for i in range(20))  # the channel choice is 7 ~ 9
            elif epoch < 25:
                channel = tuple(np.random.randint(4) + 6
                                for i in range(20))  # the channel choice is 6 ~ 9
            elif epoch < 30:
                channel = tuple(np.random.randint(5) + 5
                                for i in range(20))  # the channel choice is 5 ~ 9
            elif epoch < 35:
                channel = tuple(np.random.randint(6) + 4
                                for i in range(20))  # the channel choice is 4 ~ 9
            elif epoch < 45:
                channel = tuple(np.random.randint(7) + 3
                                for i in range(20))  # the channel choice is 3 ~ 9
            elif epoch < 50:
                channel = tuple(np.random.randint(8) + 2
                                for i in range(20))  # the channel choice is 2 ~ 9
            elif epoch < 55:
                channel = tuple(np.random.randint(9) + 1
                                for i in range(20))  # the channel choice is 1 ~ 9
            else:
                channel = tuple(np.random.randint(10)
                                for i in range(20))  # the channel choice is 0 ~ 9
            print('Defined Channel Choice: ', channel)
            channel_mask = get_channel_mask(channel,
                                            stage_repeats,
                                            stage_out_channels,
                                            candidate_scales,
                                            dtype=opt.dtype)

            for i, batch in enumerate(train_data):
                # Generate channel mask and random block choice
                cand = get_random_cand(4)
                print('Random Block Candidate: ', cand)
                cand = nd.array(cand)
                cand = cand.astype(opt.dtype, copy=False)
                #print(channel_mask)

                data, label = batch_fn(batch, ctx)
                if opt.mixup:
                    lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                    if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                        lam = 1
                    data = [lam * X + (1 - lam) * X[::-1] for X in data]
                    if opt.label_smoothing:
                        eta = 0.1
                    else:
                        eta = 0.0
                    label = mixup_transform(label, classes, lam, eta)
                elif opt.label_smoothing:
                    hard_label = label
                    label = smooth(label, classes)

                with ag.record():
                    outputs = [
                        net(X.astype(opt.dtype, copy=False),
                            cand.as_in_context(X.context),
                            channel_mask.as_in_context(X.context))
                        for X in data
                    ]
                    loss = [
                        L(yhat, y.astype(opt.dtype, copy=False))
                        for yhat, y in zip(outputs, label)
                    ]
                for l in loss:
                    l.backward()
                sw.add_scalar(tag='train_loss',
                              value=sum([l.sum().asscalar()
                                         for l in loss]) / len(loss),
                              global_step=iteration)
                # ignore_stale_grad: parameters outside the sampled subpath
                # received no gradient this step.
                trainer.step(batch_size, ignore_stale_grad=True)

                if opt.mixup:
                    output_softmax = [nd.SoftmaxActivation(out.astype('float32', copy=False)) \
                                      for out in outputs]
                    train_metric.update(label, output_softmax)
                else:
                    if opt.label_smoothing:
                        train_metric.update(hard_label, outputs)
                    else:
                        train_metric.update(label, outputs)
                train_metric_name, train_metric_score = train_metric.get()
                sw.add_scalar(
                    tag='train_{}_curves'.format(train_metric_name),
                    value=('train_{}_value'.format(train_metric_name),
                           train_metric_score),
                    global_step=iteration)

                if opt.log_interval and not (i + 1) % opt.log_interval:
                    train_metric_name, train_metric_score = train_metric.get()
                    logger.info(
                        'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'
                        % (epoch, i, batch_size * opt.log_interval /
                           (time.time() - btic), train_metric_name,
                           train_metric_score, trainer.learning_rate))
                    btic = time.time()
                iteration += 1

            if epoch == 0:
                sw.add_graph(net)
            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * i / (time.time() - tic))

            # Generate channel mask and random block choice for validation
            cand = get_random_cand(4)
            cand = nd.array(cand)
            cand = cand.astype(opt.dtype, copy=False)
            #channel_mask = get_channel_mask(channel, stage_repeats, stage_out_channels, candidate_scales, dtype=opt.dtype)
            top1_val_acc, top5_val_acc = test(net,
                                              batch_fn,
                                              ctx,
                                              train_data,
                                              val_data,
                                              cand,
                                              channel_mask,
                                              update_images=20000,
                                              update_bn=False)
            sw.add_scalar(tag='val_acc_curves',
                          value=('valid_acc_value', top1_val_acc),
                          global_step=epoch)

            logger.info('[Epoch %d] training: %s=%f' %
                        (epoch, train_metric_name, train_metric_score))
            logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f' %
                        (epoch, throughput, time.time() - tic))
            logger.info('[Epoch %d] validation: top1_acc=%f top5_acc=%f' %
                        (epoch, top1_val_acc, top5_val_acc))

            if top1_val_acc > best_val_score:
                best_val_score = top1_val_acc
                net.collect_params().save(
                    '%s/%.4f-supernet_imagenet-%d-best.params' %
                    (save_dir, best_val_score, epoch))
                trainer.save_states(
                    '%s/%.4f-supernet_imagenet-%d-best.states' %
                    (save_dir, best_val_score, epoch))

            if save_frequency and save_dir and (epoch +
                                                1) % save_frequency == 0:
                net.collect_params().save('%s/supernet_imagenet-%d.params' %
                                          (save_dir, epoch))
                trainer.save_states('%s/supernet_imagenet-%d.states' %
                                    (save_dir, epoch))

        sw.close()
        if save_frequency and save_dir:
            net.collect_params().save('%s/supernet_imagenet-%d.params' %
                                      (save_dir, opt.num_epochs - 1))
            trainer.save_states('%s/supernet_imagenet-%d.states' %
                                (save_dir, opt.num_epochs - 1))

    train(context)
def main():
    """Entry point for distributed (DDP) training of a fixed ShuffleNetV2
    architecture on ImageNet.

    Computes the global rank from node/GPU indices, builds BGR-tensor data
    pipelines, initializes the NCCL process group, wraps the model in
    DistributedDataParallel, and runs the train/validate loop until
    ``args.total_iters`` is reached, followed by a final BN-recalibration
    pass.
    """
    args = get_args()
    # Global rank = gpus_per_node * node_index + local gpu rank.
    args.world_size = args.gpus * args.nodes
    args.rank = args.gpus * args.nr + args.local_rank
    print("RANK: " + str(args.rank) + ", LOCAL RANK: " + str(args.local_rank))

    # Log: console + timestamped file under a hard-coded directory.
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('/home/admin/aihub/SinglePathOneShot/log'):
        os.mkdir('/home/admin/aihub/SinglePathOneShot/log')
    fh = logging.FileHandler(
        os.path.join(
            '/home/admin/aihub/SinglePathOneShot/log/train-{}{:02}{}'.format(
                local_time.tm_year % 2000, local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    assert os.path.exists(args.train_dir)
    train_dataset = datasets.ImageFolder(
        args.train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.ColorJitter(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4),
            transforms.RandomHorizontalFlip(0.5),
            ToBGRTensor(),
        ]))
    # Sampler shards the dataset across ranks; shuffle=False because the
    # sampler handles shuffling.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=args.rank)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=32,
                                               pin_memory=True,
                                               sampler=train_sampler)
    train_dataprovider = DataIterator(train_loader)
    assert os.path.exists(args.val_dir)
    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        args.val_dir,
        transforms.Compose([
            OpencvResize(256),
            transforms.CenterCrop(224),
            ToBGRTensor(),
        ])),
                                             batch_size=200,
                                             shuffle=False,
                                             num_workers=32,
                                             pin_memory=use_gpu)
    val_dataprovider = DataIterator(val_loader)
    print('load data successfully')

    # NOTE(review): rank=args.local_rank here, while args.rank (global rank)
    # is computed above and used by the commented-out variants — for
    # multi-node runs this looks like it should be rank=args.rank; confirm.
    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=args.world_size,
                            rank=args.local_rank)
    # dist.init_process_group(backend='nccl', init_method='tcp://'+args.ip+':'+str(args.port), world_size=args.world_size, rank=args.rank)
    # dist.init_process_group(backend='nccl', init_method="file:///mnt/nas1/share_file", world_size=args.world_size, rank=args.rank)
    torch.cuda.set_device(args.local_rank)

    # Full-width channels for every one of the 20 blocks.
    channels_scales = (1.0, ) * 20
    model = ShuffleNetV2_OneShot(architecture=list(args.arch),
                                 channels_scales=channels_scales)
    device = torch.device(args.local_rank)
    model = model.cuda(args.local_rank)

    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)
    # Linear LR decay to zero over total_iters.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda step: (1.0 - step / args.total_iters)
        if step <= args.total_iters else 0,
        last_epoch=-1)

    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank],
        find_unused_parameters=False)  #,output_device=args.local_rank) # ,
    loss_function = criterion_smooth.cuda()

    all_iters = 0
    # Stash training objects on args so train()/validate() can reach them.
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_dataprovider = train_dataprovider
    args.val_dataprovider = val_dataprovider

    if args.eval:
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint, strict=True)
            validate(model, device, args, all_iters=all_iters)
        exit(0)

    validate(model, device, args, all_iters=all_iters)
    while all_iters < args.total_iters:
        all_iters = train(model,
                          device,
                          args,
                          val_interval=args.val_interval,
                          bn_process=False,
                          all_iters=all_iters)
        validate(model, device, args, all_iters=all_iters)
    # Final pass with bn_process=True: one epoch-sized interval to settle BN
    # statistics before the last validation.
    all_iters = train(model,
                      device,
                      args,
                      val_interval=int(1280000 / args.val_batch_size),
                      bn_process=True,
                      all_iters=all_iters)
    validate(model, device, args, all_iters=all_iters)