import copy
import random


def init_popu(consts_choose, stages, channel_scales, kernels, seed_len=40, rate=(0.6, 1.05)):
    # Build the reference architecture and compute its constraint values (MACs, params).
    arch = resnet_nas(consts_choose)
    consts = constraint_cal(arch)
    # Seed the population with the reference encoding, then fill it up by mutation.
    ga_seed = [consts_choose]
    while len(ga_seed) < seed_len:
        new_idx = random.choice(range(len(ga_seed)))
        new_choose = copy.deepcopy(ga_seed[new_idx])
        new_choose = mutate(new_choose, stages, channel_scales, kernels, mutate_rate=0.8)
        # Keep the mutant only if it is new and stays inside the constraint band.
        if is_satisfy_consts(consts, new_choose, ga_seed, rate):
            ga_seed.append(new_choose)
    return ga_seed
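# A hedged usage sketch for init_popu. The encoding format mirrors the
# `arch_choose` structure used by the training entry points further down
# (per stage: a block count plus per-block [channel_scale, kernel_pair]).
# The candidate lists below (block counts, width multipliers, kernel options)
# are illustrative assumptions, not values fixed by this repository.
def init_popu_example():
    base_choose = [[3, [[1.0, (1, 3)] for _ in range(3)]],
                   [4, [[1.0, (1, 3)] for _ in range(4)]],
                   [6, [[1.0, (1, 3)] for _ in range(6)]],
                   [3, [[1.0, (1, 3)] for _ in range(3)]]]
    seed = init_popu(base_choose,
                     stages=[2, 3, 4, 6],              # assumed block-count options
                     channel_scales=[0.5, 0.75, 1.0],  # assumed width multipliers
                     kernels=[(1, 3), (1, 5)],         # assumed kernel options
                     seed_len=40, rate=(0.6, 1.05))
    return seed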
def cross_mutation(consts_choose, ga_seed, stages, channel_scales, kernels, all_population,
                   rate=(0.95, 1.05), children_len=40, max_iter=1000):
    # Reference architecture and its constraint values (MACs, params).
    arch = resnet_nas(consts_choose)
    consts = constraint_cal(arch)
    children = []
    n_iter = 0
    while len(children) < children_len and n_iter < max_iter:
        # Pick two distinct parents at random from the seed population.
        idx = list(range(len(ga_seed)))
        random.shuffle(idx)
        idx_father, idx_mother = idx[:2]
        father = ga_seed[idx_father]
        mother = ga_seed[idx_mother]
        # Crossover, then mutate each child independently.
        child_fa, child_mo = cross(father, mother)
        child_fa = mutate(child_fa, stages, channel_scales, kernels, mutate_rate=0.9)
        child_mo = mutate(child_mo, stages, channel_scales, kernels, mutate_rate=0.9)
        # Keep only children that satisfy the constraints and are not already in the population.
        if is_satisfy_consts(consts, child_fa, all_population, rate):
            children.append(child_fa)
        if is_satisfy_consts(consts, child_mo, all_population, rate):
            children.append(child_mo)
        n_iter += 1
    return children
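# init_popu and cross_mutation are typically chained into an evolutionary
# search loop: the seed population is built once, each generation produces
# constrained children, and a fitness estimator ranks candidates for survival.
# A minimal sketch is given below; `evaluate_fitness` is a hypothetical
# callable (e.g. a supernet accuracy predictor), not part of this repository.
def ga_search_sketch(consts_choose, stages, channel_scales, kernels,
                     evaluate_fitness, n_generations=10, top_k=20):
    population = init_popu(consts_choose, stages, channel_scales, kernels)
    for _ in range(n_generations):
        children = cross_mutation(consts_choose, population, stages, channel_scales,
                                  kernels, all_population=population)
        population = population + children
        # Keep the top_k encodings according to the (assumed) fitness estimator.
        population = sorted(population, key=evaluate_fitness, reverse=True)[:top_k]
    return population[0]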
def is_satisfy_consts(consts, choose, all_choose, rate=(0.6, 1.05)):
    # Reject encodings that are already in the population.
    if choose in all_choose:
        return False
    # Accept the candidate only if both its MACs and its parameter count fall
    # inside the allowed band around the reference constraints, e.g. with
    # rate=(0.6, 1.05) between 60% and 105% of the reference values.
    arch = resnet_nas(choose)
    arch_macs, arch_params = constraint_cal(arch)
    return (consts[0] * rate[0] < arch_macs < consts[0] * rate[1]
            and consts[1] * rate[0] < arch_params < consts[1] * rate[1])
def main(rank, args):
    # dist init: one process per GPU, NCCL backend, TCP rendezvous on localhost
    dist.init_process_group("nccl", init_method='tcp://localhost:12345',
                            rank=rank, world_size=args.world_size)

    # dataloader
    train_dir, valid_dir = default_dir[args.dataset]
    train_trans, valid_trans = hard_trans(args.img_size)
    trainset = MyDataset(train_dir, transform=train_trans)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset, num_replicas=args.world_size, rank=rank)
    trainloader = DataLoader(trainset, batch_size=args.bs, shuffle=False,
                             num_workers=8, pin_memory=True, sampler=train_sampler)
    validset = MyDataset(valid_dir, transform=valid_trans)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        validset, num_replicas=args.world_size, rank=rank)
    validloader = DataLoader(validset, batch_size=args.bs, shuffle=False,
                             num_workers=8, pin_memory=True, sampler=valid_sampler)

    # model
    # -------- modify here: insert the architecture obtained in step 2 --------
    arch_choose = [[3, [[1.0, (1, 3)], [1.0, (1, 3)], [1.0, (1, 3)]]],
                   [4, [[1.0, (1, 3)], [1.0, (1, 3)], [1.0, (1, 3)], [1.0, (1, 3)]]],
                   [6, [[1.0, (1, 3)], [1.0, (1, 3)], [1.0, (1, 3)],
                        [1.0, (1, 3)], [1.0, (1, 3)], [1.0, (1, 3)]]],
                   [3, [[1.0, (1, 3)], [1.0, (1, 3)], [1.0, (1, 3)]]]]
    # --------------------------------------------------------------------------
    arch = resnet_nas(arch_choose)
    if args.arch_dir is not None:
        arch.load_state_dict(load_normal(args.arch_dir))
        print('load success!')
    fc = models.__dict__[args.fc](trainloader.dataset.n_classes, arch.emb_size, args.fc_scale)
    if args.fc_dir is not None:
        fc.load_state_dict(load_normal(args.fc_dir))
    criterion = models.__dict__[args.criterion](args.criterion_times)
    arch = arch.cuda(rank)
    fc = fc.cuda(rank)
    criterion = criterion.cuda(rank)

    # optimizer: scale the learning rate by sqrt(#GPUs) for multi-GPU training
    if torch.cuda.is_available() and args.multi_gpu and torch.cuda.device_count() > 1:
        args.optim_lr *= round(math.sqrt(torch.cuda.device_count()))
    optimizer = optimizers.__dict__[args.optim]((arch, fc), args.optim_lr)
    scheduler = optimizers.__dict__[args.optim_lr_mul](optimizer, args.warmup_epoch,
                                                       args.max_epoch, len(trainloader))

    # ddp: sync BN across processes, then wrap backbone and classification head
    arch = torch.nn.SyncBatchNorm.convert_sync_batchnorm(arch)
    arch = torch.nn.parallel.DistributedDataParallel(arch, device_ids=[rank],
                                                     find_unused_parameters=True)
    fc = torch.nn.parallel.DistributedDataParallel(fc, device_ids=[rank],
                                                   find_unused_parameters=True)

    # log (rank 0 only); other ranks get a None logger so run() still receives one
    if rank == 0:
        time_str = datetime.strftime(datetime.now(), '%y-%m-%d-%H-%M-%S')
        args.log_dir = os.path.join('logs', args.dataset + '_' + time_str)
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        logger = logger_init(args.log_dir)
    else:
        logger = None
    # make sure every rank waits until the log directory exists
    dist.barrier()

    # train and valid
    run(args, arch, fc, criterion, optimizer, scheduler, trainloader, validloader, logger)
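# main(rank, args) above expects one spawned worker per GPU, with the process
# index passed in as `rank`. A minimal launcher sketch, assuming an `args`
# namespace that provides the fields read above (dataset, bs, img_size,
# optimizer settings, etc.); the exact argparse definition is not shown here.
def launch_sketch(args):
    import torch.multiprocessing as mp
    args.world_size = torch.cuda.device_count()
    # mp.spawn calls main(i, args) with i = 0 .. world_size - 1
    mp.spawn(main, args=(args,), nprocs=args.world_size, join=True)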
def main(args):
    # dist init: local_rank is supplied by the launcher, rendezvous via env vars
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    # dataloader
    train_dir, valid_dir = default_dir[args.dataset]
    train_trans, valid_trans = hard_trans(args.img_size)
    trainset = MyDataset(train_dir, transform=train_trans)
    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
    trainloader = DataLoader(trainset, batch_size=args.bs, shuffle=False,
                             num_workers=8, pin_memory=True, sampler=train_sampler)
    validset = MyDataset(valid_dir, transform=valid_trans)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(validset)
    validloader = DataLoader(validset, batch_size=args.bs, shuffle=False,
                             num_workers=8, pin_memory=True, sampler=valid_sampler)

    # model
    arch = resnet_nas(args.arch_choose)
    if args.arch_dir is not None:
        arch.load_state_dict(load_normal(args.arch_dir))
        print('load success!')
    fc = models.__dict__['cos'](trainloader.dataset.n_classes, arch.emb_size, 100.0)
    criterion = models.__dict__['cross_entropy'](1.0)
    arch = arch.cuda()
    fc = fc.cuda()
    criterion = criterion.cuda()

    # optimizer
    optimizer = optimizers.__dict__['sgd']((arch, fc), args.optim_lr)
    scheduler = optimizers.__dict__['warm_cos'](optimizer, args.warmup_epoch,
                                                args.max_epoch, len(trainloader))

    # ddp
    arch = torch.nn.SyncBatchNorm.convert_sync_batchnorm(arch)
    arch = ddp(arch, device_ids=[args.local_rank], find_unused_parameters=True)
    fc = ddp(fc, device_ids=[args.local_rank], find_unused_parameters=True)

    # log (rank 0 only)
    if args.local_rank == 0:
        time_str = datetime.strftime(datetime.now(), '%y-%m-%d-%H-%M-%S')
        args.log_dir = os.path.join('logs', args.dataset + '_' + time_str)
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        logger = logger_init(args.log_dir)
    else:
        logger = None

    # train and valid
    run(args, arch, fc, criterion, optimizer, scheduler, trainloader, validloader, logger)
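# This variant of main() is written for launcher-driven startup
# (torch.distributed.launch / torchrun), which sets MASTER_ADDR, MASTER_PORT,
# RANK and WORLD_SIZE for init_method='env://'. An illustrative invocation
# (the script name and extra flags are assumptions, not defined in this file):
#
#   python -m torch.distributed.launch --nproc_per_node=4 retrain.py --bs 128
#
# torch.distributed.launch injects --local_rank into each worker's argv, so the
# argument parser must declare it; under torchrun the local rank is exposed via
# the LOCAL_RANK environment variable instead, e.g.:
#
#   args.local_rank = int(os.environ.get('LOCAL_RANK', 0))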