Example #1
import copy
import random

def init_popu(consts_choose,
              stages,
              channel_scales,
              kernels,
              seed_len=40,
              rate=(0.6, 1.05)):
    # Reference architecture and its (MACs, params) targets for the constraint band.
    arch = resnet_nas(consts_choose)
    consts = constraint_cal(arch)

    # Grow the seed pool by mutating random members until it reaches seed_len.
    ga_seed = [consts_choose]
    while len(ga_seed) < seed_len:
        new_idx = random.randrange(len(ga_seed))
        new_choose = copy.deepcopy(ga_seed[new_idx])
        new_choose = mutate(new_choose,
                            stages,
                            channel_scales,
                            kernels,
                            mutate_rate=0.8)
        if is_satisfy_consts(consts, new_choose, ga_seed, rate):
            ga_seed.append(new_choose)
    return ga_seed
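A minimal usage sketch follows; the stages, channel_scales, and kernels values are illustrative placeholders, not the repository's actual search space (the seed encoding mirrors the arch_choose literal shown in Example #4):

# Hypothetical search-space settings; the real ones live in the repo's config.
stages = [3, 4, 6, 3]                 # blocks per stage, ResNet-50-like
channel_scales = [0.5, 0.75, 1.0]     # candidate width multipliers
kernels = [(1, 3), (3, 3)]            # candidate kernel shapes

seed_choose = [[n, [[1.0, (1, 3)] for _ in range(n)]] for n in stages]

population = init_popu(seed_choose, stages, channel_scales, kernels, seed_len=40)
print(len(population))  # 40 encodings inside the MACs/params band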
Example #2
import random

def cross_mutation(consts_choose, ga_seed, stages, channel_scales, kernels, all_population,
                   rate=(0.95, 1.05), children_len=40, max_iter=1000):
    arch = resnet_nas(consts_choose)
    consts = constraint_cal(arch)

    children = []
    n_iter = 0  # avoid shadowing the builtin iter()
    while len(children) < children_len and n_iter < max_iter:
        # Sample two distinct parents from the seed pool.
        idx_father, idx_mother = random.sample(range(len(ga_seed)), 2)
        father = ga_seed[idx_father]
        mother = ga_seed[idx_mother]

        child_fa, child_mo = cross(father, mother)
        child_fa = mutate(child_fa, stages, channel_scales, kernels, mutate_rate=0.9)
        child_mo = mutate(child_mo, stages, channel_scales, kernels, mutate_rate=0.9)

        # Keep only the children that satisfy the MACs/params constraints.
        if is_satisfy_consts(consts, child_fa, all_population, rate):
            children.append(child_fa)
        if is_satisfy_consts(consts, child_mo, all_population, rate):
            children.append(child_mo)

        n_iter += 1
    return children
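The cross helper is not shown in these examples; a plausible single-point crossover over the per-stage encodings is sketched below as an assumption, not the repository's actual implementation:

import copy
import random

def cross(father, mother):
    # Single-point crossover: swap the stage tails of the two parents.
    child_fa = copy.deepcopy(father)
    child_mo = copy.deepcopy(mother)
    point = random.randrange(1, len(father))  # cut point inside the stage list
    child_fa[point:], child_mo[point:] = child_mo[point:], child_fa[point:]
    return child_fa, child_mo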
Example #3
def is_satisfy_consts(consts, choose, all_choose, rate=(0.6, 1.05)):
    # Reject duplicates first, then keep the candidate only if its MACs and
    # parameter count both fall inside the allowed band around the targets.
    if choose in all_choose:
        return False
    arch = resnet_nas(choose)
    arch_macs, arch_params = constraint_cal(arch)
    return (consts[0] * rate[0] < arch_macs < consts[0] * rate[1]
            and consts[1] * rate[0] < arch_params < consts[1] * rate[1])
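constraint_cal returns a (MACs, params) pair that the band check above unpacks. Below is a sketch of such a helper built on thop.profile; both the use of thop and the 224x224 input resolution are assumptions, not confirmed by the source:

import torch
from thop import profile  # assumed MACs/params counter; any profiler would do

def constraint_cal(arch, img_size=224):
    # Profile one forward pass to count multiply-accumulates and parameters.
    dummy = torch.randn(1, 3, img_size, img_size)
    macs, params = profile(arch, inputs=(dummy,), verbose=False)
    return macs, params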
Example #4
def main(rank, args):
    # dist init
    dist.init_process_group("nccl", init_method='tcp://localhost:12345', rank=rank, world_size=args.world_size)

    # dataloader
    train_dir, valid_dir = default_dir[args.dataset]
    train_trans, valid_trans = hard_trans(args.img_size)

    trainset = MyDataset(train_dir, transform=train_trans)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset, num_replicas=args.world_size, rank=rank)
    trainloader = DataLoader(trainset, batch_size=args.bs, shuffle=False,
                             num_workers=8, pin_memory=True, sampler=train_sampler)

    validset = MyDataset(valid_dir, transform=valid_trans)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        validset, num_replicas=args.world_size, rank=rank)
    validloader = DataLoader(validset, batch_size=args.bs, shuffle=False,
                             num_workers=8, pin_memory=True, sampler=valid_sampler)

    # model
    # ---------------- modify here: paste the architecture obtained in step 2 ----------------
    arch_choose = [[3, [[1.0, (1, 3)], [1.0, (1, 3)], [1.0, (1, 3)]]],
                   [4, [[1.0, (1, 3)], [1.0, (1, 3)], [1.0, (1, 3)], [1.0, (1, 3)]]],
                   [6, [[1.0, (1, 3)], [1.0, (1, 3)], [1.0, (1, 3)], [1.0, (1, 3)], [1.0, (1, 3)], [1.0, (1, 3)]]],
                   [3, [[1.0, (1, 3)], [1.0, (1, 3)], [1.0, (1, 3)]]]]
    # ----------------------------------------------------------------------------
    arch = resnet_nas(arch_choose)
    if args.arch_dir is not None:
        arch.load_state_dict(load_normal(args.arch_dir))
        print('load success!')
    fc = models.__dict__[args.fc](trainloader.dataset.n_classes, arch.emb_size, args.fc_scale)
    if args.fc_dir is not None:
        fc.load_state_dict(load_normal(args.fc_dir))
    criterion = models.__dict__[args.criterion](args.criterion_times)

    arch = arch.cuda(rank)
    fc = fc.cuda(rank)
    criterion = criterion.cuda(rank)

    # optimizer (scale the base LR by roughly sqrt(#GPUs) in multi-GPU runs)
    if torch.cuda.is_available() and args.multi_gpu and torch.cuda.device_count() > 1:
        args.optim_lr *= round(math.sqrt(torch.cuda.device_count()))
    optimizer = optimizers.__dict__[args.optim]((arch, fc), args.optim_lr)
    scheduler = optimizers.__dict__[args.optim_lr_mul](optimizer, args.warmup_epoch,
                                                       args.max_epoch, len(trainloader))

    arch = torch.nn.SyncBatchNorm.convert_sync_batchnorm(arch)
    arch = torch.nn.parallel.DistributedDataParallel(arch, device_ids=[rank], find_unused_parameters=True)
    fc = torch.nn.parallel.DistributedDataParallel(fc, device_ids=[rank], find_unused_parameters=True)

    # log (rank 0 only); the other ranks run with logger=None
    if rank == 0:
        time_str = datetime.strftime(datetime.now(), '%y-%m-%d-%H-%M-%S')
        args.log_dir = os.path.join('logs', args.dataset + '_' + time_str)
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        logger = logger_init(args.log_dir)
    else:
        logger = None
    dist.barrier()  # every rank must reach the barrier, or NCCL deadlocks

    # train and valid
    run(args, arch, fc, criterion, optimizer, scheduler, trainloader, validloader, logger)
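Since this main(rank, args) receives the rank as its first positional argument and uses a fixed TCP init method, it is presumably launched with torch.multiprocessing.spawn; a minimal launcher sketch (parse_args is a hypothetical helper that builds the args namespace with bs, world_size, and the other fields used above):

import torch.multiprocessing as mp

if __name__ == '__main__':
    args = parse_args()  # hypothetical: defines bs, world_size, dataset, etc.
    mp.spawn(main, args=(args,), nprocs=args.world_size)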
Example #5
def main(args):
    # dist init
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    # dataloader
    train_dir, valid_dir = default_dir[args.dataset]
    train_trans, valid_trans = hard_trans(args.img_size)

    trainset = MyDataset(train_dir, transform=train_trans)
    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
    trainloader = DataLoader(trainset,
                             batch_size=args.bs,
                             shuffle=False,
                             num_workers=8,
                             pin_memory=True,
                             sampler=train_sampler)

    validset = MyDataset(valid_dir, transform=valid_trans)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(validset)
    validloader = DataLoader(validset,
                             batch_size=args.bs,
                             shuffle=False,
                             num_workers=8,
                             pin_memory=True,
                             sampler=valid_sampler)

    # model
    arch = resnet_nas(args.arch_choose)
    if args.arch_dir is not None:
        arch.load_state_dict(load_normal(args.arch_dir))
        print('load success!')
    fc = models.__dict__['cos'](trainloader.dataset.n_classes, arch.emb_size,
                                100.0)
    criterion = models.__dict__['cross_entropy'](1.0)
    arch = arch.cuda()
    fc = fc.cuda()
    criterion = criterion.cuda()

    # optimizer
    optimizer = optimizers.__dict__['sgd']((arch, fc), args.optim_lr)
    scheduler = optimizers.__dict__['warm_cos'](optimizer, args.warmup_epoch,
                                                args.max_epoch,
                                                len(trainloader))

    # ddp
    arch = torch.nn.SyncBatchNorm.convert_sync_batchnorm(arch)
    arch = ddp(arch, device_ids=[args.local_rank], find_unused_parameters=True)
    fc = ddp(fc, device_ids=[args.local_rank], find_unused_parameters=True)

    # log
    if args.local_rank == 0:
        time_str = datetime.strftime(datetime.now(), '%y-%m-%d-%H-%M-%S')
        args.log_dir = os.path.join('logs', args.dataset + '_' + time_str)
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        logger = logger_init(args.log_dir)
    else:
        logger = None

    # train and valid
    run(args, arch, fc, criterion, optimizer, scheduler, trainloader,
        validloader, logger)
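With init_method='env://' and args.local_rank, this variant is presumably started by python -m torch.distributed.launch --nproc_per_node=<N> train.py ..., which sets the rendezvous environment variables and passes --local_rank to each process; a minimal entry point under that assumption (all defaults are illustrative):

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=0)  # injected by the launcher
    parser.add_argument('--dataset', type=str, default='my_dataset')  # illustrative
    parser.add_argument('--bs', type=int, default=256)                # illustrative
    args = parser.parse_args()
    main(args)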