Example #1
0
def main():
    """Evaluate a slice of a pre-generated candidate list on a supernet checkpoint.

    Loads candidate architectures from a pickle file, restores the latest
    checkpoint for ``args.exp_name``, evaluates the chunk of candidates
    assigned to this GPU (a fixed 5-way split, 800 candidates each), and
    pickles the resulting error/candidate lists under ./data.
    """
    args = get_args()
    assert args.exp_name is not None

    # Partition the candidate list into 5 contiguous chunks of k=800 each;
    # args.gpu selects which chunk this process evaluates.
    k = 800
    splits = [(i * k, (i + 1) * k) for i in range(5)]
    print(splits)

    # NOTE(review): hard-coded absolute path — only valid on the original
    # author's machine; consider promoting this to a CLI argument.
    cand_path = "/home/bg141/SinglePathOneShot/src/data/loc_data_ed_15.p"
    with open(cand_path, "rb") as f:
        candidate_list = pickle.load(f)

    model = ShuffleNetV2_OneShot(input_size=args.im_size, n_class=args.num_classes)
    model = nn.DataParallel(model)
    model = model.cuda()

    lastest_model, iters = get_lastest_model(args.exp_name)
    print("Iters: ", iters)
    if lastest_model is not None:
        checkpoint = torch.load(lastest_model)
        model.load_state_dict(checkpoint['state_dict'], strict=True)
        print('load from checkpoint')

    err_list = []
    cand_list = []
    print("Split: ", splits[args.gpu])
    lo, hi = splits[args.gpu]
    for i, cand in enumerate(candidate_list[lo:hi], start=1):
        err_list.append(get_cand_err(model, cand, args))
        cand_list.append(cand)
        print("Net: ", i)

    # Use context managers so the output handles are closed deterministically
    # (the original left open file objects to the garbage collector).
    with open("./data/err-" + args.exp_name + "-" + str(args.gpu) + "_ed_15.p", "wb") as f:
        pickle.dump(err_list, f)
    with open("./data/cand-" + args.exp_name + "-" + str(args.gpu) + "_ed_15.p", "wb") as f:
        pickle.dump(cand_list, f)

    print("Finished")
    return
Example #2
0
def main3():
    """Sample 5000 FLOPs-binned random candidates and record their errors.

    Restores the latest checkpoint for ``args.exp_name``, draws candidates
    uniformly over FLOPs bins, evaluates each on the supernet, and pickles
    the error/candidate lists (with periodic snapshots every 500 nets).
    """
    args = get_args()
    assert args.exp_name is not None

    out_dir = './data/' + args.exp_name
    # exist_ok avoids the check-then-create race of the original
    # exists()/mkdir() pair and also creates ./data if missing.
    os.makedirs(out_dir, exist_ok=True)

    # A candidate is a tuple of 20 block choices, each in {0, 1, 2, 3}.
    get_random_cand = lambda: tuple(np.random.randint(4) for _ in range(20))

    # FLOPs bins in MFLOPs; with step 50 over [290, 360) this yields
    # [[290, 340], [340, 390]].
    flops_l, flops_r, flops_step = 290, 360, 50
    bins = [[i, i + flops_step] for i in range(flops_l, flops_r, flops_step)]

    def get_uniform_sample_cand(*, timeout=500):
        """Rejection-sample a candidate whose FLOPs fall inside a random bin."""
        l, r = bins[np.random.randint(len(bins))]
        for _ in range(timeout):
            cand = get_random_cand()
            if l * 1e6 <= get_cand_flops(cand) <= r * 1e6:
                return cand
        # Fall back to an unconstrained sample when no draw hit the bin.
        print("timeout")
        return get_random_cand()

    model = ShuffleNetV2_OneShot(input_size=args.im_size, n_class=args.num_classes)
    model = nn.DataParallel(model)
    model = model.cuda()

    lastest_model, iters = get_lastest_model(args.exp_name)
    if lastest_model is not None:
        checkpoint = torch.load(lastest_model)
        model.load_state_dict(checkpoint['state_dict'], strict=True)
        print('load from checkpoint')

    err_list = []
    cand_list = []
    print("GPU: ", args.gpu)
    for i in range(5000):
        cand = get_uniform_sample_cand()
        err_list.append(get_cand_err(model, cand, args))
        cand_list.append(cand)
        print("Net: ", i)
        # Periodic snapshot so partial results survive a crash.
        if i % 500 == 0:
            with open(out_dir + "/err-" + str(args.gpu) + "-" + str(i) + ".p", "wb") as f:
                pickle.dump(err_list, f)
            with open(out_dir + "/cand-" + str(args.gpu) + "-" + str(i) + ".p", "wb") as f:
                pickle.dump(cand_list, f)

    # Final dump; context managers close the handles deterministically
    # (the original leaked the open file objects).
    with open(out_dir + "/err-" + str(args.gpu) + ".p", "wb") as f:
        pickle.dump(err_list, f)
    with open(out_dir + "/cand-" + str(args.gpu) + ".p", "wb") as f:
        pickle.dump(cand_list, f)

    print("Finished")
    return
Example #3
0
def main():
    """Train the ShuffleNetV2 one-shot supernet on ImageNet folders or CIFAR-10.

    The dataset is chosen by ``args.cifar10``; training runs until
    ``args.total_iters`` unless ``args.eval`` requests a one-off validation.
    """
    args = get_args()

    # Log: mirror everything to stdout and a timestamped file under ./log.
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000,
                                                  local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    if args.cifar10 == False:
        # ImageNet-style folder dataset producing BGR tensors, with standard
        # augmentation (random resized crop, color jitter, horizontal flip).
        assert os.path.exists(args.train_dir)
        train_dataset = datasets.ImageFolder(
            args.train_dir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.ColorJitter(brightness=0.4,
                                       contrast=0.4,
                                       saturation=0.4),
                transforms.RandomHorizontalFlip(0.5),
                ToBGRTensor(),
            ]))
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=1,
                                                   pin_memory=use_gpu)
        train_dataprovider = DataIterator(train_loader)

        assert os.path.exists(args.val_dir)
        val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            args.val_dir,
            transforms.Compose([
                OpencvResize(256),
                transforms.CenterCrop(224),
                ToBGRTensor(),
            ])),
                                                 batch_size=200,
                                                 shuffle=False,
                                                 num_workers=1,
                                                 pin_memory=use_gpu)
        val_dataprovider = DataIterator(val_loader)
        print('load imagenet data successfully')

    else:
        train_transform, valid_transform = data_transforms(args)

        trainset = torchvision.datasets.CIFAR10(root=os.path.join(
            args.data_dir, 'cifar'),
                                                train=True,
                                                download=True,
                                                transform=train_transform)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   pin_memory=True,
                                                   num_workers=8)
        train_dataprovider = DataIterator(train_loader)
        valset = torchvision.datasets.CIFAR10(root=os.path.join(
            args.data_dir, 'cifar'),
                                              train=False,
                                              download=True,
                                              transform=valid_transform)
        val_loader = torch.utils.data.DataLoader(valset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=8)
        val_dataprovider = DataIterator(val_loader)

        print('load cifar10 data successfully')

    model = ShuffleNetV2_OneShot()

    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # NOTE(review): the smoothing loss is hard-wired to 1000 classes even on
    # the CIFAR-10 path — presumably the default model head is 1000-way;
    # confirm against ShuffleNetV2_OneShot's default n_class.
    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)

    if use_gpu:
        model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    # Linear LR decay from the initial LR down to 0 over total_iters.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda step: (1.0 - step / args.total_iters)
        if step <= args.total_iters else 0,
        last_epoch=-1)

    model = model.to(device)

    all_iters = 0
    if args.auto_continue:
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            print('load from checkpoint')
            # Replay the scheduler so the LR matches the resumed iteration.
            for i in range(iters):
                scheduler.step()

    # Stash the training context on args so train()/validate() can reach it.
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_dataprovider = train_dataprovider
    args.val_dataprovider = val_dataprovider

    if args.eval:
        # Evaluation-only mode: optionally load weights, validate once, exit.
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint, strict=True)
            validate(model, device, args, all_iters=all_iters)
        exit(0)

    while all_iters < args.total_iters:
        all_iters = train(model,
                          device,
                          args,
                          val_interval=args.val_interval,
                          bn_process=False,
                          all_iters=all_iters)
Example #4
0
def main():
    """Train a single searched ShuffleNetV2_OneShot architecture.

    The architecture is read from ``arch.pkl`` (required), trained to
    ``args.total_iters`` with a linearly decaying LR, validated after each
    training round, and saved as a final 'bnps-' checkpoint.
    """
    args = get_args()

    # Log: mirror everything to stdout and a timestamped file under ./log.
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
        format=log_format, datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000, local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    # ImageNet-style folder datasets producing BGR tensors.
    assert os.path.exists(args.train_dir)
    train_dataset = datasets.ImageFolder(
        args.train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(args.im_size),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomHorizontalFlip(0.5),
            ToBGRTensor(),
        ])
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=8, pin_memory=use_gpu)
    train_dataprovider = DataIterator(train_loader)

    assert os.path.exists(args.val_dir)
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(args.val_dir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(args.im_size),
            ToBGRTensor(),
        ])),
        batch_size=200, shuffle=False,
        num_workers=8, pin_memory=use_gpu
    )
    val_dataprovider = DataIterator(val_loader)
    print('load data successfully')

    arch_path='arch.pkl'

    # The searched architecture file is mandatory; there is no fallback.
    if os.path.exists(arch_path):
        with open(arch_path,'rb') as f:
            architecture=pickle.load(f)
    else:
        raise NotImplementedError
    # All 20 blocks keep their full channel width.
    channels_scales = (1.0,)*20
    model = ShuffleNetV2_OneShot(architecture=architecture, channels_scales=channels_scales, n_class=args.num_classes, input_size=args.im_size)

    print('flops:',get_flops(model))

    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    criterion_smooth = CrossEntropyLabelSmooth(args.num_classes, 0.1)

    if use_gpu:
        # model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    # Linear LR decay from the initial LR down to 0 over total_iters.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                    lambda step : (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)

    # model = model.to(device)
    # NOTE(review): .cuda() is called unconditionally, so this entry point
    # requires a GPU even though a CPU device is computed above — confirm.
    model = model.cuda()

    all_iters = 0
    if args.auto_continue:
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model, map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            print('load from checkpoint')
            # Replay the scheduler so the LR matches the resumed iteration.
            for i in range(iters):
                scheduler.step()

    # Stash the training context on args so train()/validate() can reach it.
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_dataprovider = train_dataprovider
    args.val_dataprovider = val_dataprovider

    if args.eval:
        # Evaluation-only mode: optionally load weights, validate once, exit.
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume, map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint, strict=True)
            validate(model, device, args, all_iters=all_iters)
        exit(0)
    # t is rebound here to time the training loop for the final report.
    t = time.time()
    while all_iters < args.total_iters:
        all_iters = train(model, device, args, val_interval=args.val_interval, bn_process=False, all_iters=all_iters)
        validate(model, device, args, all_iters=all_iters)
    # all_iters = train(model, device, args, val_interval=int(1280000/args.batch_size), bn_process=True, all_iters=all_iters)
    validate(model, device, args, all_iters=all_iters)
    save_checkpoint({'state_dict': model.state_dict(),}, args.total_iters, tag='bnps-')
    print("Finished {} iters in {:.3f} seconds".format(all_iters, time.time()-t))
def main():
    """Validate a mutableResNet20 supernet on MNIST over ArchLoader architectures.

    No training happens here: the latest checkpoint (if any) is restored and
    validate() runs across the architectures listed in ``args.path``.
    """
    args = get_args()

    # archLoader
    arch_loader = ArchLoader(args.path)

    # Log: mirror everything to stdout and a timestamped file under ./log.
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000,
                                                  local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    # MNIST test split, resized to 32x32 to match the model's expected input.
    val_loader = torch.utils.data.DataLoader(datasets.MNIST(
        root="./data",
        train=False,
        transform=transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True)

    print('load data successfully')

    model = mutableResNet20(10)

    # NOTE(review): the smoothing loss is built for 1000 classes but the
    # model head is 10-way — confirm this mismatch is intentional.
    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)

    if use_gpu:
        model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    model = model.to(device)
    print("load model successfully")

    all_iters = 0
    print('load from latest checkpoint')
    lastest_model, iters = get_lastest_model()
    if lastest_model is not None:
        all_iters = iters
        checkpoint = torch.load(lastest_model,
                                map_location=None if use_gpu else 'cpu')
        model.load_state_dict(checkpoint['state_dict'], strict=True)

    # Parameter wiring so validate() can reach the loss and data loader.
    args.loss_function = loss_function
    args.val_dataloader = val_loader

    print("start to validate model")

    validate(model, device, args, all_iters=all_iters, arch_loader=arch_loader)
Example #6
0
def main():
    """Evolutionary architecture search on a NAS-Bench-201 TinyNetwork supernet.

    Loads CIFAR-10/100 or ImageNet16-120 with fixed train/valid index splits,
    trains the supernet with a cosine LR schedule while an evolutionary
    controller maintains a population of candidate structures, and dumps the
    population (structure, loss, count) to a per-configuration record dir.
    """

    #LOAD CONFIGS################################################################
    args = get_args()
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_no

    # Log: mirror everything to stdout and a timestamped file under ./log.
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
        format=log_format, datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000, local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

#     cudnn.enabled=True
    # NOTE(review): the seed is passed as a str; torch coerces it via
    # int(seed), so this only works because the string is numeric — confirm.
    torch.cuda.manual_seed(str(args.rand_seed))
    random.seed(args.rand_seed)
    #LOAD DATA###################################################################
    def convert_param(original_lists):
      """Convert a ['type', value-or-list] pair from the split file into Python values."""
      ctype, value = original_lists[0], original_lists[1]
      is_list = isinstance(value, list)
      if not is_list: value = [value]
      outs = []
      for x in value:
        if ctype == 'int':
          x = int(x)
        elif ctype == 'str':
          x = str(x)
        elif ctype == 'bool':
          x = bool(int(x))
        elif ctype == 'float':
          x = float(x)
        elif ctype == 'none':
          if x.lower() != 'none':
            raise ValueError('For the none type, the value must be none instead of {:}'.format(x))
          x = None
        else:
          raise TypeError('Does not know this type : {:}'.format(ctype))
        outs.append(x)
      if not is_list: outs = outs[0]
      return outs

    if args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std  = [x / 255 for x in [68.2, 65.4, 70.4]]
        lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
        transform_train = transforms.Compose(lists)
        transform_test  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

        # Train/valid are disjoint index subsets of the official training set,
        # read from a typed-JSON split file.
        with open('../data/cifar-split.txt', 'r') as f:
            data = json.load(f)
            content = { k: convert_param(v) for k,v in data.items()}
            Arguments = namedtuple('Configure', ' '.join(content.keys()))
            content   = Arguments(**content)

        cifar_split = content
        train_split, valid_split = cifar_split.train, cifar_split.valid

        print(len(train_split),len(valid_split))

        train_dataset = datasets.CIFAR100(root='../data', train=True, download=True, transform=transform_train)


        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=False,sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
            num_workers=4, pin_memory=use_gpu)

        train_dataprovider = DataIterator(train_loader)

        # Validation also draws from train=True, restricted to valid_split.
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root='../data', train=True, download=True, transform=transform_test),
            batch_size=250, shuffle=False, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
            num_workers=4, pin_memory=use_gpu
        )

        val_dataprovider = DataIterator(val_loader)
        print('load data successfully')
        CLASS = 100
    elif args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std  = [x / 255 for x in [63.0, 62.1, 66.7]]
        lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
        transform_train = transforms.Compose(lists)
        transform_test  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
        with open('../data/cifar-split.txt', 'r') as f:
            data = json.load(f)
            content = { k: convert_param(v) for k,v in data.items()}
            Arguments = namedtuple('Configure', ' '.join(content.keys()))
            content   = Arguments(**content)

        cifar_split = content
        train_split, valid_split = cifar_split.train, cifar_split.valid

        print(len(train_split),len(valid_split))

        train_dataset = datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)


        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=False,sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
            num_workers=4, pin_memory=use_gpu)

        train_dataprovider = DataIterator(train_loader)

        # Validation also draws from train=True, restricted to valid_split.
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_test),
            batch_size=250, shuffle=False, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
            num_workers=4, pin_memory=use_gpu
        )

        val_dataprovider = DataIterator(val_loader)
        print('load data successfully')
        CLASS = 10
    elif args.dataset == 'image16':
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
        std  = [x / 255 for x in [63.22,  61.26 , 65.09]]
        transform_test  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
        with open('../data/ImageNet16-120-split.txt', 'r') as f:
            data = json.load(f)
            content = { k: convert_param(v) for k,v in data.items()}
            Arguments = namedtuple('Configure', ' '.join(content.keys()))
            content   = Arguments(**content)
        img_split = content
        train_split, valid_split = img_split.train, img_split.valid
        # Truncate both splits to whole batches.
        train_split = train_split[:len(train_split)//args.batch_size*args.batch_size]
        valid_split = valid_split[:len(valid_split)//250*250]
        print(len(train_split),len(valid_split))
        # NOTE(review): both datasets load the training portion (train=True)
        # with the eval transform; validation is the held-out valid_split
        # index subset — confirm this (and no train augmentation) is intended.
        train_dataset = ImageNet16('../data', True , transform_test,120)
        test_dataset  = ImageNet16('../data', True, transform_test,120)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=False,sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
            num_workers=4, pin_memory=use_gpu)

        train_dataprovider = DataIterator(train_loader)

        val_loader = torch.utils.data.DataLoader(
                test_dataset,
            batch_size=250, shuffle=False, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
            num_workers=4, pin_memory=use_gpu
        )

        val_dataprovider = DataIterator(val_loader)

        print('load data successfully')
        CLASS = 120

    # NOTE(review): CLASS and the data providers are unbound if args.dataset
    # is not one of the three handled values — this would raise NameError.
    print(CLASS)
    print(args.init_channels,args.stacks//3)
    model = TinyNetwork(C=args.init_channels,N=args.stacks//3,max_nodes = 4, num_classes = CLASS, search_space = NAS_BENCH_201, affine = False, track_running_stats = False).cuda()


    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    criterion_smooth = CrossEntropyLabelSmooth(CLASS, 0.1)

    if use_gpu:

        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda" )

    else:

        loss_function = criterion_smooth
        device = torch.device("cpu")

    # Cosine LR annealing over the full iteration budget.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,args.total_iters)
    model = model.to(device)

    all_iters = 0
    if args.auto_continue:
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model, map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            print('load from checkpoint')
            # Replay the scheduler so the LR matches the resumed iteration.
            for i in range(iters):
                scheduler.step()

    # Stash the training context on args so train()/validate() can reach it.
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler

    args.train_dataprovider = train_dataprovider
    args.val_dataprovider = val_dataprovider

    args.evo_controller = evolutionary(args.max_population,args.select_number, args.mutation_len,args.mutation_number,args.p_opwise,args.evo_momentum)

    # The record directory name encodes the full run configuration.
    path = './record_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(args.dataset,args.stacks,args.init_channels,args.total_iters,args.warmup_iters,args.max_population,args.select_number,args.mutation_len,args.mutation_number,args.val_interval,args.p_opwise,args.evo_momentum,args.rand_seed)

    logging.info(path)

    # presumably the number of active stacked cells, grown at the milestone
    # iterations below — TODO confirm against TinyNetwork.
    model.current_N = 1

    while all_iters < args.total_iters:

        # Grow the network depth at fixed iteration milestones.
        if all_iters in [15000,30000,45000,60000]:
#         if all_iters in [50,100,150,200]:
#             print("----------")
            model.current_N += 1

        # Periodically dump the current population and, once past warmup,
        # let the evolutionary controller cull it.
        if all_iters > 1 and all_iters % args.val_interval == 0:
            results = []
            for structure_father in args.evo_controller.group:
                results.append([structure_father.structure,structure_father.loss,structure_father.count])
            if not os.path.exists(path):
                os.mkdir(path)

            with open(path + '/%06d-ep.txt'%all_iters,'w') as tt:
                json.dump(results,tt)

            if all_iters >= args.warmup_iters:#warmup
                args.evo_controller.select()

        all_iters = train(model, device, args, val_interval=args.val_interval, bn_process=False, all_iters=all_iters)
        validate(model, device, args, all_iters=all_iters)

    # Final population dump after the iteration budget is exhausted.
    results = []
    for structure_father in args.evo_controller.group:
        results.append([structure_father.structure,structure_father.loss,structure_father.count])
    with open(path + '/%06d-ep.txt'%all_iters,'w') as tt:
        json.dump(results,tt)
def main():
    """Distributed (DDP) supernet training for mutableResNet20.

    Each process pins itself to ``local_rank % num_gpus``, joins the NCCL
    process group, and splits the configured global batch size across ranks.
    Optionally warms up the shared weights, then alternates subnet training
    with periodic validation over architectures from ``ArchLoader``.
    """
    args = get_args()
    num_gpus = torch.cuda.device_count()
    args.gpu = args.local_rank % num_gpus
    torch.cuda.set_device(args.gpu)

    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.world_size = torch.distributed.get_world_size()
    # The configured batch size is global; each rank gets an equal share.
    args.batch_size = args.batch_size // args.world_size

    # archLoader
    arch_loader = ArchLoader(args.path)

    # Log: mirror everything to stdout and a timestamped file under ./log.
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m-%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}-{:02}-{:02}-{:.3f}'.format(
            local_time.tm_year % 2000, local_time.tm_mon, local_time.tm_mday,
            t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    train_loader = get_train_loader(args.batch_size, args.local_rank,
                                    args.num_workers, args.total_iters)

    val_loader = get_val_loader(args.batch_size, args.num_workers)

    model = mutableResNet20()

    logging.info('load model successfully')

    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # NOTE(review): the smoothing loss is built for 1000 classes; confirm it
    # matches the head size of mutableResNet20() as constructed here.
    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)

    if use_gpu:
        # model = nn.DataParallel(model)
        model = model.cuda(args.gpu)
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
        loss_function = criterion_smooth.cuda()
    else:
        loss_function = criterion_smooth

    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5)

    all_iters = 0

    if args.auto_continue:  # auto-resume from the latest checkpoint?
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            logging.info('load from checkpoint')
            # Replay the scheduler so the LR matches the resumed iteration.
            for i in range(iters):
                scheduler.step()

    # Parameter wiring: stash the training context on args for the helpers.
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_loader = train_loader
    args.val_loader = val_loader

    if args.eval:
        # Evaluation-only mode: optionally load weights, validate once, exit.
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint, strict=True)
            validate(model, args, all_iters=all_iters, arch_loader=arch_loader)
        exit(0)

    # warmup weights
    if args.warmup > 0:
        logging.info("begin warmup weights")
        while all_iters < args.warmup:
            all_iters = train_supernet(model,
                                       args,
                                       bn_process=False,
                                       all_iters=all_iters)

        validate(model, args, all_iters=all_iters, arch_loader=arch_loader)

    while all_iters < args.total_iters:
        logging.info("=" * 50)
        all_iters = train_subnet(model,
                                 args,
                                 bn_process=False,
                                 all_iters=all_iters,
                                 arch_loader=arch_loader)

        # Only rank 0 validates, and only every 200 iterations.
        if all_iters % 200 == 0 and args.local_rank == 0:
            logging.info("validate iter {}".format(all_iters))

            validate(model, args, all_iters=all_iters, arch_loader=arch_loader)
Example #8
0
def main():
    """Train each candidate architecture in one split from scratch and record
    its validation accuracy.

    Loads a pickled architecture list from ``cl400.p``, selects the
    10-architecture slice indexed by ``args.split_num``, trains each
    ShuffleNetV2_OneShot subnet to ``args.total_iters`` (plus a final
    BN-recalibration pass), and incrementally dumps ``{arch: [top1, top5]}``
    to ``from_scratch_split_<split_num>.pkl``.
    """
    args = get_args()

    # Log to stdout and to a timestamped file under ./log.
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000,
                                                  local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    # Training data: 96x96 random crops with color jitter, BGR tensor layout.
    assert os.path.exists(args.train_dir)
    train_dataset = datasets.ImageFolder(
        args.train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(96),
            transforms.ColorJitter(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4),
            transforms.RandomHorizontalFlip(0.5),
            ToBGRTensor(),
        ]))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=use_gpu)
    train_dataprovider = DataIterator(train_loader)

    # Validation data: deterministic resize only (no augmentation).
    assert os.path.exists(args.val_dir)
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(
            args.val_dir,
            transforms.Compose([
                OpencvResize(96),
                # transforms.CenterCrop(96),
                ToBGRTensor(),
            ])),
        batch_size=200,
        shuffle=False,
        num_workers=4,
        pin_memory=use_gpu)
    val_dataprovider = DataIterator(val_loader)

    # Pre-generated candidate list; this script only consumes it.
    arch_path = 'cl400.p'

    if os.path.exists(arch_path):
        with open(arch_path, 'rb') as f:
            architectures = pickle.load(f)
    else:
        raise NotImplementedError
    # All 20 choice blocks run at full channel scale.
    channels_scales = (1.0, ) * 20
    cands = {}
    # 40 contiguous slices of 10 architectures each; this worker handles one.
    splits = [(i, 10 + i) for i in range(0, 400, 10)]
    architectures = np.array(architectures)
    architectures = architectures[
        splits[args.split_num][0]:splits[args.split_num][1]]
    print(len(architectures))
    logging.info("Training and Validating arch: " +
                 str(splits[args.split_num]))
    for architecture in architectures:
        architecture = tuple(architecture.tolist())
        # Fresh model/optimizer/scheduler per architecture: each candidate is
        # trained from scratch, not fine-tuned from the previous one.
        model = ShuffleNetV2_OneShot(architecture=architecture,
                                     channels_scales=channels_scales,
                                     n_class=10,
                                     input_size=96)

        print('flops:', get_flops(model))

        optimizer = torch.optim.SGD(get_parameters(model),
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        # NOTE(review): loss is built for 1000 classes while the model uses
        # n_class=10 — confirm this mismatch is intended.
        criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)

        if use_gpu:
            model = nn.DataParallel(model)
            loss_function = criterion_smooth.cuda()
            device = torch.device("cuda")
        else:
            loss_function = criterion_smooth
            device = torch.device("cpu")

        # Linear decay from the base LR to 0 over total_iters steps.
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lambda step: (1.0 - step / args.total_iters)
            if step <= args.total_iters else 0,
            last_epoch=-1)

        model = model.to(device)

        all_iters = 0
        if args.auto_continue:
            lastest_model, iters = get_lastest_model()
            if lastest_model is not None:
                all_iters = iters
                checkpoint = torch.load(
                    lastest_model, map_location=None if use_gpu else 'cpu')
                model.load_state_dict(checkpoint['state_dict'], strict=True)
                print('load from checkpoint')
                # Replay scheduler steps so the LR matches the resumed iter.
                for i in range(iters):
                    scheduler.step()

        # Stash training collaborators on args; train()/validate() read them.
        args.optimizer = optimizer
        args.loss_function = loss_function
        args.scheduler = scheduler
        args.train_dataprovider = train_dataprovider
        args.val_dataprovider = val_dataprovider
        # print("BEGIN VALDATE: ", args.eval, args.eval_resume)
        if args.eval:
            if args.eval_resume is not None:
                checkpoint = torch.load(
                    args.eval_resume, map_location=None if use_gpu else 'cpu')
                model.load_state_dict(checkpoint, strict=True)
                validate(model, device, args, all_iters=all_iters)
            exit(0)
        # t1,t5 = validate(model, device, args, all_iters=all_iters)
        # print("VALDATE: ", t1, "   ", t5)

        # Main training loop, validating every val_interval iterations.
        while all_iters < args.total_iters:
            all_iters = train(model,
                              device,
                              args,
                              val_interval=args.val_interval,
                              bn_process=False,
                              all_iters=all_iters)
            validate(model, device, args, all_iters=all_iters)
        # One extra pass with bn_process=True to recalibrate BN statistics
        # before the final accuracy measurement.
        all_iters = train(model,
                          device,
                          args,
                          val_interval=int(1280000 / args.batch_size),
                          bn_process=True,
                          all_iters=all_iters)
        top1, top5 = validate(model, device, args, all_iters=all_iters)
        save_checkpoint({
            'state_dict': model.state_dict(),
        },
                        args.total_iters,
                        tag='bnps-')
        cands[architecture] = [top1, top5]
        # Dump after every architecture so partial results survive a crash.
        # NOTE(review): the file handle from open() is never closed — consider
        # a with-block.
        pickle.dump(
            cands,
            open("from_scratch_split_{}.pkl".format(args.split_num), 'wb'))
Example #9
0
def pipeline(args, reporter):
    """Train one supernet task (identified by ``args.task_id``) on CIFAR.

    Sets up per-task logging under ``args.local/save/<signal>/task_id_<id>``,
    builds the CIFAR-10/100 loaders, constructs a supernet variant selected by
    ``args.block``, optionally assigns per-parameter-group learning rates when
    ``args.different_hpo`` is set, and trains until ``args.total_iters``,
    reporting progress through ``reporter``.
    """
    # Log for one Supernet
    floder = '{}/task_id_{}'.format(args.signal, args.task_id)
    # NOTE(review): 'arg' (not 'args') — presumably a module-level object;
    # verify it exists, otherwise this raises NameError.
    path = os.path.join(arg.local, 'save', floder)
    if not os.path.isdir(path):
        os.makedirs(path)
    args.path = path

    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)

    if not os.path.exists('{}/log'.format(path)):
        os.mkdir('{}/log'.format(path))
    fh = logging.FileHandler(
        os.path.join('{}/log/{}-task_id{}-train-{}{:02}{}'.format(path, args['signal'], args['task_id'], local_time.tm_year % 2000, local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info('{}-task_id: {}'.format(args.signal, args.task_id))

    # resource
    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    # load dataset — CIFAR-10 or CIFAR-100 chosen by num_classes, with
    # RandAugment parameters N/M.
    if args.num_classes==10:
        dataset_train, dataset_valid = dataset_cifar.get_dataset("cifar10", N=args.randaug_n, M=args.randaug_m, RandA=args.RandA)
    elif args.num_classes==100:
        dataset_train, dataset_valid = dataset_cifar.get_dataset("cifar100", N=args.randaug_n, M=args.randaug_m, RandA=args.RandA)

    # Stratified train/valid split; split=0.0 disables it, so the full
    # training set is used and valid_sampler is empty.
    split = 0.0
    split_idx = 0
    train_sampler = None
    if split > 0.0:
        sss = StratifiedShuffleSplit(n_splits=5, test_size=split, random_state=0)
        sss = sss.split(list(range(len(dataset_train))), dataset_train.targets)
        for _ in range(split_idx + 1):
            train_idx, valid_idx = next(sss)
        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetSampler(valid_idx)
    else:
        valid_sampler = SubsetSampler([])

    train_loader = torch.utils.data.DataLoader(
        dataset_train, batch_size=args.batch_size, shuffle=True if train_sampler is None else False, num_workers=32,
        pin_memory=True,
        sampler=train_sampler, drop_last=True)

    # valid_loader = torch.utils.data.DataLoader(
    #     dataset_train, batch_size=args.batch_size, shuffle=False, num_workers=16, pin_memory=True,
    #     sampler=valid_sampler, drop_last=False)

    #
    valid_loader = torch.utils.data.DataLoader(
        dataset_valid, batch_size=args.batch_size, shuffle=False, num_workers=16, pin_memory=True,
        drop_last=False)

    train_dataprovider = DataIterator(train_loader)
    val_dataprovider = DataIterator(valid_loader)
    args.test_interval = len(valid_loader)
    args.val_interval = int(len(dataset_train) / args.batch_size) # step
    print('load data successfully')

    # network — supernet variant selected by block count.
    if args.block == 5:
        model = ShuffleNetV2_OneShot_cifar(block=args['block'], n_class=args.num_classes)
    elif args.block == 12:
        model = SuperNetwork(shadow_bn=True, layers=args['block'], classes=args.num_classes)
        print("param size = %fMB" % count_parameters_in_MB(model))
    elif args.block == 4:
        model = Network(num_classes=args.num_classes) # model = Network(net()).to(device).half()
    elif args.block == 3:
        model = Network_cifar(num_classes=args.num_classes)


    # lr and parameters
    # original optimizer lr & wd

    # test lr_range
    # args.learning_rate = args.learning_rate * (args['task_id']+ 1)

    # parameters divided into groups
    # test shuffle lr_group (4 stage * 5choice + 1base_lr == 21)
    # test mobile lr_group (12 stage * 12 choice + 1base_lr == 145)
    # test fast lr_group (3 stage * 1 choice + 1base_lr == 4)

    # lr_group = [i/100 for i in list(range(4,25,1))]
    # arch_search = list(np.random.randint(2) for i in range(5*2))
    # optimizer = torch.optim.SGD(get_dif_lr_parameters(model, lr_group, arch_search),

    # Per-group random learning rates (one LR per choice block + a base LR);
    # the sampling range differs per supernet variant.
    if args.different_hpo:
        if args['block']==5:
            nums_lr_group = args['block'] * args['choice'] + 1
            lr_group = list(np.random.uniform(0.4, 0.8) for i in range(nums_lr_group))
            optimizer = torch.optim.SGD(shuffle_dif_lr_parameters(model, lr_group),
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
        elif args['block']==12:
            nums_lr_group=145
            lr_group = list(np.random.uniform(0.1, 0.3) for i in range(nums_lr_group))
            optimizer = torch.optim.SGD(mobile_dif_lr_parameters(model, lr_group),
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)

        elif args['block']==4:
            nums_lr_group=4
            # NOTE(review): 'arg.lr_range' (not args) — same concern as above.
            lr_l, lr_r = float(arg.lr_range.split(',')[0]), float(arg.lr_range.split(',')[1])
            lr_group = list(np.random.uniform(lr_l, lr_r) for i in range(nums_lr_group))
            optimizer = torch.optim.SGD(fast_dif_lr_parameters(model, lr_group),
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)

        elif args['block'] == 3:
            nums_lr_group = 19 # 9 * 2 + 1
            lr_l, lr_r = float(arg.lr_range.split(',')[0]), float(arg.lr_range.split(',')[1])
            lr_group = list(np.random.uniform(lr_l, lr_r) for i in range(nums_lr_group))
            optimizer = torch.optim.SGD(fast_19_lr_parameters(model, lr_group),
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
            # log lr
            # for param_group in optimizer.param_groups:
            #     print(param_group['lr'])

            # save optim
            # torch.save(optimizer.state_dict(),'optimizer.pt')
            # optimizer.load_state_dict(torch.load('optimizer.pt'))

    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.learning_rate, # without hpo / glboal hpo
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # optimizer = torch.optim.SGD(get_parameters(model),
    #                             lr=args.learning_rate,
    #                             momentum=args.momentum,
    #                             weight_decay=args.weight_decay)

    # lookahead optimizer
    # base_opt = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
    # optimizer = Lookahead(base_opt, k=5, alpha=0.5)

    # blockly optimizer
    # base_opt_2 = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))
    # base_opt_3 = torch.optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.999))
    # base_opt_group = [base_opt, base_opt_2, base_opt_3]
    # optimizer = BlocklyOptimizer(base_opt_group, k=5, alpha=0.5)

    # loss func, ls=0.1
    criterion_smooth = CrossEntropyLabelSmooth(10, args['label_smooth'])

    # lr_scheduler is related to total_iters — linear decay to 0.
    scheduler = torch.optim.lr_scheduler.LambdaLR \
        (optimizer, lambda step: (1.0 - step / args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    #     optimizer, float(args.total_iters / args.val_interval), eta_min=1e-8, last_epoch=-1)

    if use_gpu:
        model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    model = model.to(device)

    all_iters = 0
    if args.auto_continue: # load model
        lastest_model, iters = get_lastest_model(args.path)
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model, map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            print('load from checkpoint')
            for i in range(iters):
                scheduler.step() # lr Align

    # Stash collaborators on args; train()/validate() read them from there.
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler

    args.train_dataprovider = train_dataprovider
    args.val_dataprovider = val_dataprovider

    if args.eval:
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume, map_location=None if use_gpu else 'cpu')
            # model.load_state_dict(checkpoint, strict=True)
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            validate(model, device, args, all_iters=all_iters)
        exit(0)

    # according to total_iters
    while all_iters < args.total_iters:
        all_iters, Top1_acc = \
            train(model, device, args, bn_process=True, all_iters=all_iters, reporter=reporter)
Example #10
0
def main():
    """Train a mutableResNet20 supernet on CIFAR-100.

    Builds the CIFAR-100 loaders, an SGD optimizer with linearly decaying LR,
    optionally resumes from the latest checkpoint, and runs train() with
    architectures sampled from ``arch_loader`` until ``args.total_iters``.
    """
    args = get_args()

    # archLoader — supplies candidate sub-architectures during training.
    arch_loader = ArchLoader(args.path)

    # Log to stdout and to a timestamped file under ./log.
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000,
                                                  local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    train_dataset, val_dataset = get_dataset('cifar100')

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=16,
                                               pin_memory=True)
    # train_dataprovider = DataIterator(train_loader)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=200,
                                             shuffle=False,
                                             num_workers=12,
                                             pin_memory=True)

    # val_dataprovider = DataIterator(val_loader)
    print('load data successfully')

    model = mutableResNet20()

    print('load model successfully')

    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # NOTE(review): loss is built for 1000 classes while the data is
    # CIFAR-100 — confirm this mismatch is intended.
    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)

    if use_gpu:
        model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    # Linear decay from the base LR to 0 over total_iters steps.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda step: (1.0 - step / args.total_iters)
        if step <= args.total_iters else 0,
        last_epoch=-1)

    model = model.to(device)

    # dp_model = torch.nn.parallel.DistributedDataParallel(model)

    all_iters = 0
    if args.auto_continue:  # auto-resume?
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            print('load from checkpoint')
            # Replay scheduler steps so the LR matches the resumed iter.
            for i in range(iters):
                scheduler.step()

    # Parameter setup: stash collaborators on args; train()/validate() read
    # them from there.
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_loader = train_loader
    args.val_loader = val_loader
    # args.train_dataprovider = train_dataprovider
    # args.val_dataprovider = val_dataprovider

    # Eval-only mode: load weights, validate once, and exit.
    if args.eval:
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint, strict=True)
            validate(model,
                     device,
                     args,
                     all_iters=all_iters,
                     arch_loader=arch_loader)
        exit(0)

    while all_iters < args.total_iters:
        all_iters = train(model,
                          device,
                          args,
                          val_interval=args.val_interval,
                          bn_process=False,
                          all_iters=all_iters,
                          arch_loader=arch_loader,
                          arch_batch=args.arch_batch)
Example #11
0
def main():
    """Train a mutableResNet20 supernet on MNIST.

    Builds MNIST loaders (resized to 32x32), optionally warms up the full
    supernet weights for ``args.warmup`` iterations, then alternates
    train_subnet() steps (against a frozen ``base_model`` copy) with periodic
    validation over architectures from ``arch_loader``.
    """
    args = get_args()

    # archLoader — supplies candidate sub-architectures during training.
    arch_loader = ArchLoader(args.path)

    # Log to stdout and to a timestamped file under ./log.
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m-%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}-{:02}-{:02}-{:.3f}'.format(
            local_time.tm_year % 2000, local_time.tm_mon, local_time.tm_mday,
            t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    kwargs = {'num_workers': 4, 'pin_memory': True}

    # MNIST resized to 32x32 with standard mean/std normalization.
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        root="./data",
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)

    val_loader = torch.utils.data.DataLoader(datasets.MNIST(
        root="./data",
        train=False,
        transform=transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             **kwargs)

    model = mutableResNet20(num_classes=10)
    # Untrained snapshot of the supernet, passed to train_subnet() below —
    # presumably as a reference/teacher copy; confirm against train_subnet.
    base_model = copy.deepcopy(model)

    logging.info('load model successfully')

    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # NOTE(review): loss is built for 1000 classes while the model uses
    # num_classes=10 — confirm this mismatch is intended.
    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)

    if use_gpu:
        model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
        base_model.cuda()
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
    #                                               lambda step: (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    #     optimizer, T_max=200)

    model = model.to(device)

    all_iters = 0

    if args.auto_continue:
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            logging.info('load from checkpoint')
            # Replay scheduler steps so the LR matches the resumed iter.
            for i in range(iters):
                scheduler.step()

    # Parameter setup: stash collaborators on args; the train/validate
    # helpers read them from there.
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_loader = train_loader
    args.val_loader = val_loader

    # Eval-only mode: load weights, validate once, and exit.
    if args.eval:
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint, strict=True)
            validate(model,
                     device,
                     args,
                     all_iters=all_iters,
                     arch_loader=arch_loader)
        exit(0)

    # warmup weights — trains the whole supernet before subnet sampling.
    # NOTE(review): guarded by `is not None` (not `> 0`); warmup=0 would
    # still enter this branch — verify get_args' default.
    if args.warmup is not None:
        logging.info("begin warmup weights")
        while all_iters < args.warmup:
            all_iters = train_supernet(model,
                                       device,
                                       args,
                                       bn_process=False,
                                       all_iters=all_iters)

        validate(model,
                 device,
                 args,
                 all_iters=all_iters,
                 arch_loader=arch_loader)

    while all_iters < args.total_iters:
        all_iters = train_subnet(model,
                                 base_model,
                                 device,
                                 args,
                                 bn_process=False,
                                 all_iters=all_iters,
                                 arch_loader=arch_loader)
        logging.info("validate iter {}".format(all_iters))

        # Validate every 9 iterations (as counted by train_subnet's return).
        if all_iters % 9 == 0:
            validate(model,
                     device,
                     args,
                     all_iters=all_iters,
                     arch_loader=arch_loader)

    # Final validation after training completes.
    validate(model, device, args, all_iters=all_iters, arch_loader=arch_loader)