Example #1
    def generate_synthetic_data(self, exclude=[], maxn=1000):
        synthetic_encoder_input = []
        synthetic_encoder_target = []
        while len(synthetic_encoder_input) < maxn:
            synthetic_arch = utils.generate_arch(1, self.source_length, self.vocab_size-1)[0]
            synthetic_arch = utils.parse_arch_to_seq(synthetic_arch)
            if synthetic_arch not in exclude and synthetic_arch not in synthetic_encoder_input:
                synthetic_encoder_input.append(synthetic_arch)
    
        nao_synthetic_dataset = NAODataset(synthetic_encoder_input, None, False)
        nao_synthetic_queue = torch.utils.data.DataLoader(
            nao_synthetic_dataset,
            batch_size=len(nao_synthetic_dataset),
            shuffle=False,
            pin_memory=True)

        self.eval()
        with torch.no_grad():
            for sample in nao_synthetic_queue:
                encoder_input = sample['encoder_input'].cuda()
                _, _, _, predict_value = self.encoder(encoder_input)
                synthetic_encoder_target += predict_value.data.squeeze().tolist()
        assert len(synthetic_encoder_input) == len(synthetic_encoder_target)
        return synthetic_encoder_input, synthetic_encoder_target
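
A minimal, self-contained sketch of the pattern above (one full-batch pass under torch.no_grad() to pseudo-label synthetic inputs), using a toy linear model and random tensors instead of the repository's NAODataset and encoder; every name below is a placeholder.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

inputs = torch.randn(100, 8)                      # stand-in for encoded architectures
queue = DataLoader(TensorDataset(inputs),
                   batch_size=len(inputs),        # one full batch, as in the example
                   shuffle=False)

predictor = nn.Linear(8, 1)                       # stand-in for the encoder/predictor
predictor.eval()
targets = []
with torch.no_grad():
    for (batch,) in queue:
        targets += predictor(batch).squeeze().tolist()
assert len(targets) == len(inputs)
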
Example #2
def generate_synthetic_controller_data(model, exclude=[], maxn=1000):
    synthetic_input = []
    synthetic_target = []
    while len(synthetic_input) < maxn:
        synthetic_arch = utils.generate_arch(1, args.layers, args.num_ops)[0]
        if synthetic_arch not in exclude and synthetic_arch not in synthetic_input:
            synthetic_input.append(synthetic_arch)

    synthetic_dataset = utils.ControllerDataset(synthetic_input, None, False)
    synthetic_queue = torch.utils.data.DataLoader(
        synthetic_dataset,
        batch_size=len(synthetic_dataset),
        shuffle=False,
        pin_memory=True)

    with torch.no_grad():
        model.eval()
        for sample in synthetic_queue:
            input = utils.move_to_cuda(sample['encoder_input'])
            _, _, _, predict_value = model.encoder(input)
            synthetic_target += predict_value.data.squeeze().tolist()
    assert len(synthetic_input) == len(synthetic_target)
    return synthetic_input, synthetic_target
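
Both examples above declare `exclude=[]` as a default argument. Because the list is never mutated here it is harmless, but the pattern is a well-known Python pitfall; a minimal illustration in plain Python, unrelated to the repository code:

def risky(item, acc=[]):          # the same list object is reused across calls
    acc.append(item)
    return acc

def safe(item, acc=None):         # idiomatic: create a fresh list per call
    acc = [] if acc is None else acc
    acc.append(item)
    return acc

assert risky(1) == [1]
assert risky(2) == [1, 2]         # state leaks between calls
assert safe(1) == [1]
assert safe(2) == [2]             # independent calls
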
Example #3
def train():
    child_params = get_child_model_params()
    controller_params = get_controller_params()
    corpus = data.Corpus(child_params['data_dir'])
    eval_batch_size = child_params['eval_batch_size']

    train_data = batchify(corpus.train, child_params['batch_size'],
                          child_params['cuda'])
    val_data = batchify(corpus.valid, eval_batch_size, child_params['cuda'])
    ntokens = len(corpus.dictionary)

    if os.path.exists(os.path.join(child_params['model_dir'], 'model.pt')):
        print("Found model.pt in {}, automatically continue training.".format(
            os.path.join(child_params['model_dir'])))
        continue_train_child = True
    else:
        continue_train_child = False

    if continue_train_child:
        child_model = torch.load(
            os.path.join(child_params['model_dir'], 'model.pt'))
    else:
        child_model = model_search.RNNModelSearch(
            ntokens, child_params['emsize'], child_params['nhid'],
            child_params['nhidlast'], child_params['dropout'],
            child_params['dropouth'], child_params['dropoutx'],
            child_params['dropouti'], child_params['dropoute'],
            child_params['drop_path'])

    if os.path.exists(os.path.join(controller_params['model_dir'],
                                   'model.pt')):
        print("Found model.pt in {}, automatically continue training.".format(
            os.path.join(child_params['model_dir'])))
        continue_train_controller = True
    else:
        continue_train_controller = False

    if continue_train_controller:
        controller_model = torch.load(
            os.path.join(controller_params['model_dir'], 'model.pt'))
    else:
        controller_model = controller.Controller(controller_params)

    size = 0
    for p in child_model.parameters():
        size += p.nelement()
    logging.info('child model param size: {}'.format(size))
    size = 0
    for p in controller_model.parameters():
        size += p.nelement()
    logging.info('controller model param size: {}'.format(size))

    if args.cuda:
        if args.single_gpu:
            parallel_child_model = child_model.cuda()
            parallel_controller_model = controller_model.cuda()
        else:
            parallel_child_model = nn.DataParallel(child_model, dim=1).cuda()
            parallel_controller_model = nn.DataParallel(controller_model,
                                                        dim=1).cuda()
    else:
        parallel_child_model = child_model
        parallel_controller_model = controller_model

    total_params = sum(x.data.nelement() for x in child_model.parameters())
    logging.info('Args: {}'.format(args))
    logging.info('Child Model total parameters: {}'.format(total_params))
    total_params = sum(x.data.nelement()
                       for x in controller_model.parameters())
    logging.info('Args: {}'.format(args))
    logging.info('Controller Model total parameters: {}'.format(total_params))

    # Loop over epochs.

    if continue_train_child:
        optimizer_state = torch.load(
            os.path.join(child_params['model_dir'], 'optimizer.pt'))
        if 't0' in optimizer_state['param_groups'][0]:
            child_optimizer = torch.optim.ASGD(
                child_model.parameters(),
                lr=child_params['lr'],
                t0=0,
                lambd=0.,
                weight_decay=child_params['wdecay'])
        else:
            child_optimizer = torch.optim.SGD(
                child_model.parameters(),
                lr=child_params['lr'],
                weight_decay=child_params['wdecay'])
        child_optimizer.load_state_dict(optimizer_state)
        child_epoch = torch.load(
            os.path.join(child_params['model_dir'], 'misc.pt'))['epoch'] - 1
    else:
        child_optimizer = torch.optim.SGD(child_model.parameters(),
                                          lr=child_params['lr'],
                                          weight_decay=child_params['wdecay'])
        child_epoch = 0

    if continue_train_controller:
        optimizer_state = torch.load(
            os.path.join(controller_params['model_dir'], 'optimizer.pt'))
        controller_optimizer = torch.optim.Adam(
            controller_model.parameters(),
            lr=controller_params['lr'],
            weight_decay=controller_params['weight_decay'])
        controller_optimizer.load_state_dict(optimizer_state)
        controller_epoch = torch.load(
            os.path.join(controller_params['model_dir'],
                         'misc.pt'))['epoch'] - 1
    else:
        controller_optimizer = torch.optim.Adam(
            controller_model.parameters(),
            lr=controller_params['lr'],
            weight_decay=controller_params['weight_decay'])
        controller_epoch = 0
    eval_every_epochs = child_params['eval_every_epochs']
    while True:
        # Train child model
        if child_params['arch_pool'] is None:
            arch_pool = generate_arch(
                controller_params['num_seed_arch'])  #[[arch]]
            child_params['arch_pool'] = arch_pool
        child_params['arch'] = None

        if isinstance(eval_every_epochs, int):
            child_params['eval_every_epochs'] = eval_every_epochs
        else:
            eval_every_epochs = list(map(int, eval_every_epochs))
            for index, e in enumerate(eval_every_epochs):
                if child_epoch < e:
                    child_params['eval_every_epochs'] = e
                    break

        for e in range(child_params['eval_every_epochs']):
            child_epoch += 1
            model_search.train(train_data, child_model, parallel_child_model,
                               child_optimizer, child_params, child_epoch)
            if child_epoch % child_params['eval_every_epochs'] == 0:
                save_checkpoint(child_model, child_optimizer, child_epoch,
                                child_params['model_dir'])
                logging.info('Saving Model!')
            if child_epoch >= child_params['train_epochs']:
                break

        # Evaluate seed archs
        valid_accuracy_list = model_search.evaluate(val_data, child_model,
                                                    parallel_child_model,
                                                    child_params,
                                                    eval_batch_size)

        # Output archs and evaluated error rate
        old_archs = child_params['arch_pool']
        old_archs_perf = valid_accuracy_list

        old_archs_sorted_indices = np.argsort(old_archs_perf)
        old_archs = np.array(old_archs)[old_archs_sorted_indices].tolist()
        old_archs_perf = np.array(
            old_archs_perf)[old_archs_sorted_indices].tolist()
        with open(
                os.path.join(child_params['model_dir'],
                             'arch_pool.{}'.format(child_epoch)), 'w') as fa:
            with open(
                    os.path.join(child_params['model_dir'],
                                 'arch_pool.perf.{}'.format(child_epoch)),
                    'w') as fp:
                with open(os.path.join(child_params['model_dir'], 'arch_pool'),
                          'w') as fa_latest:
                    with open(
                            os.path.join(child_params['model_dir'],
                                         'arch_pool.perf'), 'w') as fp_latest:
                        for arch, perf in zip(old_archs, old_archs_perf):
                            arch = ' '.join(map(str, arch))
                            fa.write('{}\n'.format(arch))
                            fa_latest.write('{}\n'.format(arch))
                            fp.write('{}\n'.format(perf))
                            fp_latest.write('{}\n'.format(perf))

        if child_epoch >= child_params['train_epochs']:
            logging.info('Training finished!')
            break

        # Train Encoder-Predictor-Decoder
        # [[arch]]
        encoder_input = list(map(lambda x: parse_arch_to_seq(x), old_archs))
        encoder_target = normalize_target(old_archs_perf)
        decoder_target = copy.copy(encoder_input)
        controller_params['batches_per_epoch'] = math.ceil(
            len(encoder_input) / controller_params['batch_size'])
        controller_epoch = controller.train(encoder_input, encoder_target,
                                            decoder_target, controller_model,
                                            parallel_controller_model,
                                            controller_optimizer,
                                            controller_params,
                                            controller_epoch)

        # Generate new archs
        new_archs = []
        controller_params['predict_lambda'] = 0
        top100_archs = list(
            map(lambda x: parse_arch_to_seq(x), old_archs[:100]))
        max_step_size = controller_params['max_step_size']
        while len(new_archs) < controller_params['max_new_archs']:
            controller_params['predict_lambda'] += 1
            new_arch = controller.infer(top100_archs, controller_model,
                                        parallel_controller_model,
                                        controller_params)
            for arch in new_arch:
                if arch not in encoder_input and arch not in new_archs:
                    new_archs.append(arch)
                if len(new_archs) >= controller_params['max_new_archs']:
                    break
            logging.info('{} new archs generated now'.format(len(new_archs)))
            if controller_params['predict_lambda'] >= max_step_size:
                break
        #[[arch]]
        new_archs = list(map(lambda x: parse_seq_to_arch(x),
                             new_archs))  #[[arch]]
        num_new_archs = len(new_archs)
        logging.info("Generate {} new archs".format(num_new_archs))
        random_new_archs = generate_arch(50)
        new_arch_pool = old_archs[:len(old_archs) - num_new_archs -
                                  50] + new_archs + random_new_archs
        logging.info("Totally {} archs now to train".format(
            len(new_arch_pool)))
        child_params['arch_pool'] = new_arch_pool
        with open(os.path.join(child_params['model_dir'], 'arch_pool'),
                  'w') as f:
            for arch in new_arch_pool:
                arch = ' '.join(map(str, arch))
                f.write('{}\n'.format(arch))
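
The four nested `with open(...)` blocks above can be flattened. A sketch of the same write pattern using a single `with` statement over a temporary directory (paths and data below are made up for illustration):

import os, tempfile

model_dir = tempfile.mkdtemp()
archs = [([0, 2, 1, 3], 0.912), ([1, 1, 2, 0], 0.887)]   # toy (arch, perf) pairs

with open(os.path.join(model_dir, 'arch_pool'), 'w') as fa_latest, \
     open(os.path.join(model_dir, 'arch_pool.perf'), 'w') as fp_latest:
    for arch, perf in archs:
        fa_latest.write(' '.join(map(str, arch)) + '\n')
        fp_latest.write('{}\n'.format(perf))
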
Example #4
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)
        
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = True
    cudnn.deterministic = True

    args.device_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
    args.lr = args.lr
    args.batch_size = args.batch_size
    args.eval_batch_size = args.eval_batch_size
    args.width_stages = [int(val) for val in args.width_stages.split(',')]
    args.n_cell_stages = [int(val) for val in args.n_cell_stages.split(',')]
    args.stride_stages = [int(val) for val in args.stride_stages.split(',')]
    args.num_class = 1000
    args.num_ops = len(utils.OPERATIONS)

    logging.info("args = %s", args)

    feature_name = get_feature_name()

    if args.prune_feature_order == 1:
        prune_func = prune_uni_search_space
    else:
        prune_func = prune_bi_search_space

    params = {
        'boosting_type': 'gbdt',
        'objective': 'regression',
        'metric': {'l2'},
        'num_leaves': args.controller_leaves,
        'learning_rate': args.controller_lr,
        'feature_fraction': 0.9,
        'bagging_fraction': 0.8,
        'bagging_freq': 5,
        'verbose': 0
    }
    
    if args.arch_pool is not None:
        logging.info('Architecture pool is provided, loading')
        with open(args.arch_pool) as f:
            archs = f.read().splitlines()
            archs = list(map(lambda x:list(map(int, x.strip().split())), archs))
            child_arch_pool = archs
    else:
        child_arch_pool = None

    train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler = build_imagenet(None, None, valid_num=5000, epoch=-1)

    if child_arch_pool is None:
        logging.info('Architecture pool is not provided, randomly generating now')
        child_arch_pool = utils.generate_arch(args.controller_n, args.layers, args.num_ops)

    arch_pool = []
    arch_pool_valid_acc = []
    child_arch_pool_prob = None
    epoch = 1
    max_num_updates = args.max_num_updates
    pruned_operations = {}
    for controller_iteration in range(args.controller_iterations+1):
        logging.info('Iteration %d', controller_iteration+1)
        num_updates = 0
        while True:
            lr = scheduler.get_lr()[0]
            logging.info('epoch %d lr %e', epoch, lr)
            # sample an arch to train
            train_acc, train_obj, num_updates = child_train(train_queue, model, optimizer, num_updates, child_arch_pool+arch_pool[:200], child_arch_pool_prob, train_criterion)
            epoch += 1
            scheduler.step()
            if num_updates >= max_num_updates:
                break
    
        logging.info("Evaluate seed archs")
        arch_pool += child_arch_pool
        arch_pool_valid_acc = child_valid(valid_queue, model, arch_pool, eval_criterion)

        arch_pool_valid_acc_sorted_indices = np.argsort(arch_pool_valid_acc)[::-1]
        arch_pool = list(map(lambda x:arch_pool[x], arch_pool_valid_acc_sorted_indices))
        arch_pool_valid_acc = list(map(lambda x:arch_pool_valid_acc[x], arch_pool_valid_acc_sorted_indices))
        with open(os.path.join(args.output_dir, 'arch_pool.{}'.format(controller_iteration)), 'w') as fa:
            with open(os.path.join(args.output_dir, 'arch_pool.perf.{}'.format(controller_iteration)), 'w') as fp:
                for arch, perf in zip(arch_pool, arch_pool_valid_acc):
                    arch = ' '.join(map(str, arch))
                    fa.write('{}\n'.format(arch))
                    fp.write('{}\n'.format(perf))
        if controller_iteration == args.controller_iterations:
            break
                            
        # Train GBDT
        logging.info('Train GBDT')
        inputs = arch_pool
        min_val = min(arch_pool_valid_acc)
        max_val = max(arch_pool_valid_acc)
        targets = list(map(lambda x: (x - min_val) / (max_val - min_val), arch_pool_valid_acc))

        logging.info('Train GBDT')
        gbm = train_controller(params, feature_name, inputs, targets, args.controller_num_boost_round)
        
        prune_func(gbm, inputs, targets, pruned_operations)

        # Ranking sampled candidates
        random_arch = utils.generate_constrained_arch(args.controller_m, args.layers, args.num_ops, pruned_operations)
        logging.info('Totally {} archs sampled from the search space'.format(len(random_arch)))
        random_arch_features = np.array(list(map(utils.convert_to_features, random_arch)))
        random_arch_pred = gbm.predict(random_arch_features, num_iteration=gbm.best_iteration)
        sorted_indices = np.argsort(random_arch_pred)[::-1]
        random_arch = [random_arch[i] for i in sorted_indices]
        new_arch = []
        for arch in random_arch:
            if arch in arch_pool:
                continue
            new_arch.append(arch)
            if len(new_arch) >= args.controller_k:
                break
        #arch_pool += new_arch

        logging.info("Generate %d new archs", len(new_arch))
        child_arch_pool = new_arch #+ arch_pool[:200]

    logging.info('Finish Searching')
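
The targets passed to the GBDT above are min-max normalised validation accuracies. A small standalone version of that normalisation, with a guard for the degenerate case where all accuracies are equal (which the example does not handle):

def normalize(values):
    lo, hi = min(values), max(values)
    if hi == lo:                       # avoid division by zero
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]

assert normalize([1.0, 1.5, 2.0]) == [0.0, 0.5, 1.0]
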
Example #5
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = False
    cudnn.deterministic = True

    args.steps = int(np.ceil(
        45000 / args.child_batch_size)) * args.child_epochs

    logging.info("args = %s", args)

    if args.child_arch_pool is not None:
        logging.info('Architecture pool is provided, loading')
        with open(args.child_arch_pool) as f:
            archs = f.read().splitlines()
            archs = list(map(utils.build_dag, archs))
            child_arch_pool = archs
    elif os.path.exists(os.path.join(args.output_dir, 'arch_pool')):
        logging.info('Architecture pool found, loading')
        with open(os.path.join(args.output_dir, 'arch_pool')) as f:
            archs = f.read().splitlines()
            archs = list(map(utils.build_dag, archs))
            child_arch_pool = archs
    else:
        child_arch_pool = None

    child_eval_epochs = eval(args.child_eval_epochs)
    build_fn = get_builder(args.dataset)
    train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler = build_fn(
        ratio=0.9, epoch=-1)

    nao = NAO(
        args.controller_encoder_layers,
        args.controller_encoder_vocab_size,
        args.controller_encoder_hidden_size,
        args.controller_encoder_dropout,
        args.controller_encoder_length,
        args.controller_source_length,
        args.controller_encoder_emb_size,
        args.controller_mlp_layers,
        args.controller_mlp_hidden_size,
        args.controller_mlp_dropout,
        args.controller_decoder_layers,
        args.controller_decoder_vocab_size,
        args.controller_decoder_hidden_size,
        args.controller_decoder_dropout,
        args.controller_decoder_length,
    )
    nao = nao.cuda()
    logging.info("Encoder-Predictor-Decoder param size = %fMB",
                 utils.count_parameters_in_MB(nao))

    # Train child model
    if child_arch_pool is None:
        logging.info(
            'Architecture pool is not provided, randomly generating now')
        child_arch_pool = utils.generate_arch(args.controller_seed_arch,
                                              args.child_nodes,
                                              5)  # [[[conv],[reduc]]]
    if args.child_sample_policy == 'params':
        child_arch_pool_prob = []
        for arch in child_arch_pool:
            if args.dataset == 'cifar10':
                tmp_model = NASNetworkCIFAR(
                    args, 10, args.child_layers, args.child_nodes,
                    args.child_channels, args.child_keep_prob,
                    args.child_drop_path_keep_prob, args.child_use_aux_head,
                    args.steps, arch)
            elif args.dataset == 'cifar100':
                tmp_model = NASNetworkCIFAR(
                    args, 100, args.child_layers, args.child_nodes,
                    args.child_channels, args.child_keep_prob,
                    args.child_drop_path_keep_prob, args.child_use_aux_head,
                    args.steps, arch)
            else:
                tmp_model = NASNetworkImageNet(
                    args, 1000, args.child_layers, args.child_nodes,
                    args.child_channels, args.child_keep_prob,
                    args.child_drop_path_keep_prob, args.child_use_aux_head,
                    args.steps, arch)
            child_arch_pool_prob.append(
                utils.count_parameters_in_MB(tmp_model))
            del tmp_model
    else:
        child_arch_pool_prob = None

    eval_points = utils.generate_eval_points(child_eval_epochs, 0,
                                             args.child_epochs)
    step = 0
    for epoch in range(1, args.child_epochs + 1):
        scheduler.step()
        lr = scheduler.get_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)
        # sample an arch to train
        train_acc, train_obj, step = child_train(train_queue, model, optimizer,
                                                 step, child_arch_pool,
                                                 child_arch_pool_prob,
                                                 train_criterion)
        logging.info('train_acc %f', train_acc)

        if epoch not in eval_points:
            continue
        # Evaluate seed archs
        valid_accuracy_list = child_valid(valid_queue, model, child_arch_pool,
                                          eval_criterion)

        # Output archs and evaluated error rate
        old_archs = child_arch_pool
        old_archs_perf = valid_accuracy_list

        old_archs_sorted_indices = np.argsort(old_archs_perf)[::-1]
        old_archs = [old_archs[i] for i in old_archs_sorted_indices]
        old_archs_perf = [old_archs_perf[i] for i in old_archs_sorted_indices]
        with open(os.path.join(args.output_dir, 'arch_pool.{}'.format(epoch)),
                  'w') as fa:
            with open(
                    os.path.join(args.output_dir,
                                 'arch_pool.perf.{}'.format(epoch)),
                    'w') as fp:
                with open(os.path.join(args.output_dir, 'arch_pool'),
                          'w') as fa_latest:
                    with open(os.path.join(args.output_dir, 'arch_pool.perf'),
                              'w') as fp_latest:
                        for arch, perf in zip(old_archs, old_archs_perf):
                            arch = ' '.join(map(str, arch[0] + arch[1]))
                            fa.write('{}\n'.format(arch))
                            fa_latest.write('{}\n'.format(arch))
                            fp.write('{}\n'.format(perf))
                            fp_latest.write('{}\n'.format(perf))

        if epoch == args.child_epochs:
            break

        # Train Encoder-Predictor-Decoder
        logging.info('Training Encoder-Predictor-Decoder')
        encoder_input = list(
            map(lambda x: utils.parse_arch_to_seq(x[0], 2) +
                utils.parse_arch_to_seq(x[1], 2), old_archs))
        # [[conv, reduc]]
        min_val = min(old_archs_perf)
        max_val = max(old_archs_perf)
        encoder_target = [(i - min_val) / (max_val - min_val)
                          for i in old_archs_perf]

        if args.controller_expand is not None:
            dataset = list(zip(encoder_input, encoder_target))
            n = len(dataset)
            ratio = 0.9
            split = int(n * ratio)
            np.random.shuffle(dataset)
            encoder_input, encoder_target = list(zip(*dataset))
            train_encoder_input = list(encoder_input[:split])
            train_encoder_target = list(encoder_target[:split])
            valid_encoder_input = list(encoder_input[split:])
            valid_encoder_target = list(encoder_target[split:])
            for _ in range(args.controller_expand - 1):
                for src, tgt in zip(encoder_input[:split],
                                    encoder_target[:split]):
                    a = np.random.randint(0, args.child_nodes)
                    b = np.random.randint(0, args.child_nodes)
                    src = src[:4 * a] + src[4 * a + 2:4 * a + 4] + \
                            src[4 * a:4 * a + 2] + src[4 * (a + 1):20 + 4 * b] + \
                            src[20 + 4 * b + 2:20 + 4 * b + 4] + src[20 + 4 * b:20 + 4 * b + 2] + \
                            src[20 + 4 * (b + 1):]
                    train_encoder_input.append(src)
                    train_encoder_target.append(tgt)
        else:
            train_encoder_input = encoder_input
            train_encoder_target = encoder_target
            valid_encoder_input = encoder_input
            valid_encoder_target = encoder_target
        logging.info('Train data: {}\tValid data: {}'.format(
            len(train_encoder_input), len(valid_encoder_input)))

        nao_train_dataset = utils.NAODataset(
            train_encoder_input,
            train_encoder_target,
            True,
            swap=True if args.controller_expand is None else False)
        nao_valid_dataset = utils.NAODataset(valid_encoder_input,
                                             valid_encoder_target, False)
        nao_train_queue = torch.utils.data.DataLoader(
            nao_train_dataset,
            batch_size=args.controller_batch_size,
            shuffle=True,
            pin_memory=True)
        nao_valid_queue = torch.utils.data.DataLoader(
            nao_valid_dataset,
            batch_size=args.controller_batch_size,
            shuffle=False,
            pin_memory=True)
        nao_optimizer = torch.optim.Adam(nao.parameters(),
                                         lr=args.controller_lr,
                                         weight_decay=args.controller_l2_reg)
        for nao_epoch in range(1, args.controller_epochs + 1):
            nao_loss, nao_mse, nao_ce = nao_train(nao_train_queue, nao,
                                                  nao_optimizer)
            logging.info("epoch %04d train loss %.6f mse %.6f ce %.6f",
                         nao_epoch, nao_loss, nao_mse, nao_ce)
            if nao_epoch % 100 == 0:
                pa, hs = nao_valid(nao_valid_queue, nao)
                logging.info("Evaluation on valid data")
                logging.info(
                    'epoch %04d pairwise accuracy %.6f hamming distance %.6f',
                    nao_epoch, pa, hs)

        # Generate new archs
        new_archs = []
        max_step_size = 50
        predict_step_size = 0
        top100_archs = list(
            map(lambda x: utils.parse_arch_to_seq(x[0], 2) +
                utils.parse_arch_to_seq(x[1], 2), old_archs[:100]))
        nao_infer_dataset = utils.NAODataset(top100_archs, None, False)
        nao_infer_queue = torch.utils.data.DataLoader(
            nao_infer_dataset,
            batch_size=len(nao_infer_dataset),
            shuffle=False,
            pin_memory=True)
        while len(new_archs) < args.controller_new_arch:
            predict_step_size += 1
            logging.info('Generate new architectures with step size %d',
                         predict_step_size)
            new_arch = nao_infer(nao_infer_queue,
                                 nao,
                                 predict_step_size,
                                 direction='+')
            for arch in new_arch:
                if arch not in encoder_input and arch not in new_archs:
                    new_archs.append(arch)
                if len(new_archs) >= args.controller_new_arch:
                    break
            logging.info('%d new archs generated now', len(new_archs))
            if predict_step_size > max_step_size:
                break
                # [[conv, reduc]]
        new_archs = list(
            map(lambda x: utils.parse_seq_to_arch(x, 2),
                new_archs))  # [[[conv],[reduc]]]
        num_new_archs = len(new_archs)
        logging.info("Generate %d new archs", num_new_archs)
        # replace bottom archs
        if args.controller_replace:
            new_arch_pool = old_archs[:len(old_archs) - (num_new_archs + args.controller_random_arch)] + \
                            new_archs + utils.generate_arch(args.controller_random_arch, 5, 5)
        # discard all archs except top k
        elif args.controller_discard:
            new_arch_pool = old_archs[:100] + new_archs + utils.generate_arch(
                args.controller_random_arch, 5, 5)
        # use all
        else:
            new_arch_pool = old_archs + new_archs + utils.generate_arch(
                args.controller_random_arch, 5, 5)
        logging.info("Totally %d architectures now to train",
                     len(new_arch_pool))

        child_arch_pool = new_arch_pool
        with open(os.path.join(args.output_dir, 'arch_pool'), 'w') as f:
            for arch in new_arch_pool:
                arch = ' '.join(map(str, arch[0] + arch[1]))
                f.write('{}\n'.format(arch))

        if args.child_sample_policy == 'params':
            child_arch_pool_prob = []
            for arch in child_arch_pool:
                if args.dataset == 'cifar10':
                    tmp_model = NASNetworkCIFAR(
                        args, 10, args.child_layers, args.child_nodes,
                        args.child_channels, args.child_keep_prob,
                        args.child_drop_path_keep_prob,
                        args.child_use_aux_head, args.steps, arch)
                elif args.dataset == 'cifar100':
                    tmp_model = NASNetworkCIFAR(
                        args, 100, args.child_layers, args.child_nodes,
                        args.child_channels, args.child_keep_prob,
                        args.child_drop_path_keep_prob,
                        args.child_use_aux_head, args.steps, arch)
                else:
                    tmp_model = NASNetworkImageNet(
                        args, 1000, args.child_layers, args.child_nodes,
                        args.child_channels, args.child_keep_prob,
                        args.child_drop_path_keep_prob,
                        args.child_use_aux_head, args.steps, arch)
                child_arch_pool_prob.append(
                    utils.count_parameters_in_MB(tmp_model))
                del tmp_model
        else:
            child_arch_pool_prob = None
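
The `controller_expand` branch above augments the encoder data by swapping two 2-token slices inside the first 20 tokens of each sequence (and, offset by 20, inside the second 20 tokens). A toy check that this swap preserves the sequence length and token multiset, assuming the 40-token, 4-tokens-per-node layout implied by the slicing:

import numpy as np

src = list(range(40))                  # stand-in for a 40-token architecture sequence
a, b = np.random.randint(0, 5), np.random.randint(0, 5)
swapped = src[:4 * a] + src[4 * a + 2:4 * a + 4] + \
          src[4 * a:4 * a + 2] + src[4 * (a + 1):20 + 4 * b] + \
          src[20 + 4 * b + 2:20 + 4 * b + 4] + src[20 + 4 * b:20 + 4 * b + 2] + \
          src[20 + 4 * (b + 1):]
assert len(swapped) == len(src)
assert sorted(swapped) == sorted(src)
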
Example #6
def main():
    arch_pool = utils.generate_arch(args.controller_seed_arch, 5, 5)
    valid_arch_pool = utils.generate_arch(100, 5, 5)
    train_encoder_input = list(
        map(lambda x: utils.parse_arch_to_seq(x[0], 2) +
            utils.parse_arch_to_seq(x[1], 2), arch_pool))
    valid_encoder_input = list(
        map(lambda x: utils.parse_arch_to_seq(x[0], 2) +
            utils.parse_arch_to_seq(x[1], 2), valid_arch_pool))
    train_encoder_target = [
        np.random.random() for i in range(args.controller_seed_arch)
    ]
    valid_encoder_target = [np.random.random() for i in range(100)]
    nao = NAO(
        args.controller_encoder_layers,
        args.controller_encoder_vocab_size,
        args.controller_encoder_hidden_size,
        args.controller_encoder_dropout,
        args.controller_encoder_length,
        args.controller_source_length,
        args.controller_encoder_emb_size,
        args.controller_mlp_layers,
        args.controller_mlp_hidden_size,
        args.controller_mlp_dropout,
        args.controller_decoder_layers,
        args.controller_decoder_vocab_size,
        args.controller_decoder_hidden_size,
        args.controller_decoder_dropout,
        args.controller_decoder_length,
    )
    logging.info("param size = %fMB", utils.count_parameters_in_MB(nao))
    nao = nao.cuda()
    nao_train_dataset = utils.NAODataset(train_encoder_input,
                                         train_encoder_target,
                                         True,
                                         swap=True)
    nao_valid_dataset = utils.NAODataset(valid_encoder_input,
                                         valid_encoder_target, False)
    nao_train_queue = torch.utils.data.DataLoader(
        nao_train_dataset,
        batch_size=args.controller_batch_size,
        shuffle=True,
        pin_memory=True)
    nao_valid_queue = torch.utils.data.DataLoader(
        nao_valid_dataset,
        batch_size=len(nao_valid_dataset),
        shuffle=False,
        pin_memory=True)
    nao_optimizer = torch.optim.Adam(nao.parameters(),
                                     lr=args.controller_lr,
                                     weight_decay=args.controller_l2_reg)
    for nao_epoch in range(1, args.controller_epochs + 1):
        nao_loss, nao_mse, nao_ce = nao_train(nao_train_queue, nao,
                                              nao_optimizer)
        if nao_epoch % 10 == 0:
            logging.info("epoch %04d train loss %.6f mse %.6f ce %.6f",
                         nao_epoch, nao_loss, nao_mse, nao_ce)
        if nao_epoch % 100 == 0:
            pa, hs = nao_valid(nao_valid_queue, nao)
            logging.info("Evaluation on training data")
            logging.info(
                'epoch %04d pairwise accuracy %.6f hamming distance %.6f',
                nao_epoch, pa, hs)

    new_archs = []
    max_step_size = 100
    predict_step_size = 0
    top100_archs = list(
        map(lambda x: utils.parse_arch_to_seq(x[0], 2) +
            utils.parse_arch_to_seq(x[1], 2), arch_pool[:100]))
    nao_infer_dataset = utils.NAODataset(top100_archs, None, False)
    nao_infer_queue = torch.utils.data.DataLoader(
        nao_infer_dataset,
        batch_size=len(nao_infer_dataset),
        shuffle=False,
        pin_memory=True)
    while len(new_archs) < args.controller_new_arch:
        predict_step_size += 1
        logging.info('Generate new architectures with step size %d',
                     predict_step_size)
        new_arch = nao_infer(nao_infer_queue, nao, predict_step_size)
        for arch in new_arch:
            if arch not in train_encoder_input and arch not in new_archs:
                new_archs.append(arch)
            if len(new_archs) >= args.controller_new_arch:
                break
        logging.info('%d new archs generated now', len(new_archs))
        if predict_step_size > max_step_size:
            break
            # [[conv, reduc]]
    new_archs = list(map(lambda x: utils.parse_seq_to_arch(x, 2),
                         new_archs))  # [[[conv],[reduc]]]
    num_new_archs = len(new_archs)
    logging.info("Generate %d new archs", num_new_archs)
    new_arch_pool = arch_pool + new_archs + utils.generate_arch(
        args.controller_random_arch, 5, 5)
    logging.info("Totally %d archs now to train", len(new_arch_pool))
    arch_pool = new_arch_pool
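
Several of the examples rank the architecture pool best-first with `np.argsort(...)[::-1]`. A small check of that pattern, since argsort alone returns ascending order:

import numpy as np

accs = [0.91, 0.95, 0.89, 0.93]
order = np.argsort(accs)[::-1]            # indices of the best architectures first
assert list(order) == [1, 3, 0, 2]
assert [accs[i] for i in order] == sorted(accs, reverse=True)
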
Example #7
def main():
    random.seed(args.seed)
    np.random.seed(args.seed)
    
    logging.info("Args = %s", args)

    nasbench = api.NASBench(os.path.join(args.data, 'nasbench_only108.tfrecord'))

    params = {
        'boosting_type': 'gbdt',
        'objective': 'regression',
        'metric': {'l2'},
        'num_leaves': args.leaves,
        'learning_rate': args.lr,
        'feature_fraction': 0.9,
        'bagging_fraction': 0.8,
        'bagging_freq': 5,
        'verbose': 0
    }

    mean_val = 0.908192
    std_val = 0.023961
    all_valid_accs = []
    all_test_accs = []
    for i in range(args.num_runs):
        logging.info('{} run'.format(i+1))
        arch_pool, seq_pool, valid_accs = utils.generate_arch(args.n, nasbench, need_perf=True)
        feature_name = utils.get_feature_name()
               
        for ii in range(args.iterations):
            normed_train_perfs = [(i-mean_val)/std_val for i in valid_accs]
            train_x = np.array(seq_pool)
            train_y = np.array(normed_train_perfs)
        
            # Train GBDT-NAS
            lgb_train = lgb.Dataset(train_x, train_y)

            gbm = lgb.train(params, lgb_train, feature_name=feature_name, num_boost_round=args.num_boost_round)

            all_arch, all_seq = utils.generate_arch(None, nasbench)
            logging.info('Totally {} archs from the search space'.format(len(all_seq)))
            all_pred = gbm.predict(np.array(all_seq), num_iteration=gbm.best_iteration)
            sorted_indices = np.argsort(all_pred)[::-1]
            all_arch = [all_arch[i] for i in sorted_indices]
            all_seq = [all_seq[i] for i in sorted_indices]
            new_arch, new_seq, new_valid_acc = [], [], []
            for arch, seq in zip(all_arch, all_seq):
                if seq in seq_pool:
                    continue
                new_arch.append(arch)
                new_seq.append(seq)
                new_valid_acc.append(nasbench.query(arch)['validation_accuracy'])
                if len(new_arch) >= args.k:
                    break
            arch_pool += new_arch
            seq_pool += new_seq
            valid_accs+= new_valid_acc
        
        sorted_indices = np.argsort(valid_accs)[::-1]
        best_arch = arch_pool[sorted_indices[0]]
        best_arch_valid_acc = valid_accs[sorted_indices[0]]
        fs, cs = nasbench.get_metrics_from_spec(best_arch)
        test_acc = np.mean([cs[108][i]['final_test_accuracy'] for i in range(3)])
        all_valid_accs.append(best_arch_valid_acc)
        all_test_accs.append(test_acc)
        logging.info('current valid accuracy: {}'.format(best_arch_valid_acc))
        logging.info('current mean test accuracy: {}'.format(np.mean(test_acc)))
        logging.info('average valid accuracy: {}'.format(np.mean(all_valid_accs)))
        logging.info('average mean test accuracy: {}'.format(np.mean(all_test_accs)))
        logging.info('best valid accuracy: {}'.format(np.max(all_valid_accs)))
        logging.info('best mean test accuracy: {}'.format(np.max(all_test_accs)))

    logging.info('average valid accuracy: {}'.format(np.mean(all_valid_accs)))
    logging.info('average mean test accuracy: {}'.format(np.mean(all_test_accs)))
    logging.info('best valid accuracy: {}'.format(np.max(all_valid_accs)))
    logging.info('best mean test accuracy: {}'.format(np.max(all_test_accs)))
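
A minimal, self-contained sketch of the LightGBM calls used above, run on random data (requires the lightgbm package; the hyperparameters here are illustrative, not the ones taken from args):

import numpy as np
import lightgbm as lgb

train_x = np.random.randint(0, 5, size=(200, 27)).astype(float)   # toy arch sequences
train_y = np.random.rand(200)                                      # toy accuracies

params = {
    'boosting_type': 'gbdt',
    'objective': 'regression',
    'metric': {'l2'},
    'num_leaves': 31,
    'learning_rate': 0.05,
    'verbose': -1,
}
gbm = lgb.train(params, lgb.Dataset(train_x, train_y), num_boost_round=100)
pred = gbm.predict(train_x[:10])
assert pred.shape == (10,)
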
Example #8
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = True
    cudnn.deterministic = True

    if args.dataset == 'cifar10':
        args.num_class = 10
    elif args.dataset == 'cifar100':
        args.num_class = 100
    else:
        args.num_class = 10

    if args.search_space == 'small':
        OPERATIONS = OPERATIONS_search_small
    elif args.search_space == 'middle':
        OPERATIONS = OPERATIONS_search_middle
    args.child_num_ops = len(OPERATIONS)
    args.controller_encoder_vocab_size = 1 + (args.child_nodes + 2 -
                                              1) + args.child_num_ops
    args.controller_decoder_vocab_size = args.controller_encoder_vocab_size
    args.steps = int(np.ceil(
        45000 / args.child_batch_size)) * args.child_epochs

    logging.info("args = %s", args)

    if args.child_arch_pool is not None:
        logging.info('Architecture pool is provided, loading')
        with open(args.child_arch_pool) as f:
            archs = f.read().splitlines()
            archs = list(map(utils.build_dag, archs))
            child_arch_pool = archs
    elif os.path.exists(os.path.join(args.output_dir, 'arch_pool')):
        logging.info('Architecture pool found, loading')
        with open(os.path.join(args.output_dir, 'arch_pool')) as f:
            archs = f.read().splitlines()
            archs = list(map(utils.build_dag, archs))
            child_arch_pool = archs
    else:
        child_arch_pool = None

    build_fn = get_builder(args.dataset)
    train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler = build_fn(
        ratio=0.9, epoch=-1)

    nao = NAO(
        args.controller_encoder_layers,
        args.controller_encoder_vocab_size,
        args.controller_encoder_hidden_size,
        args.controller_encoder_dropout,
        args.controller_encoder_length,
        args.controller_source_length,
        args.controller_encoder_emb_size,
        args.controller_mlp_layers,
        args.controller_mlp_hidden_size,
        args.controller_mlp_dropout,
        args.controller_decoder_layers,
        args.controller_decoder_vocab_size,
        args.controller_decoder_hidden_size,
        args.controller_decoder_dropout,
        args.controller_decoder_length,
    )
    nao = nao.cuda()
    logging.info("Encoder-Predictor-Decoder param size = %fMB",
                 utils.count_parameters_in_MB(nao))

    if child_arch_pool is None:
        logging.info(
            'Architecture pool is not provided, randomly generating now')
        child_arch_pool = utils.generate_arch(
            args.controller_seed_arch, args.child_nodes,
            args.child_num_ops)  # [[[conv],[reduc]]]
    arch_pool = []
    arch_pool_valid_acc = []
    for i in range(4):
        logging.info('Iteration %d', i)

        child_arch_pool_prob = []
        for arch in child_arch_pool:
            if args.dataset == 'cifar10':
                tmp_model = NASNetworkCIFAR(
                    args, args.num_class, args.child_layers, args.child_nodes,
                    args.child_channels, args.child_keep_prob,
                    args.child_drop_path_keep_prob, args.child_use_aux_head,
                    args.steps, arch)
            elif args.dataset == 'cifar100':
                tmp_model = NASNetworkCIFAR(
                    args, args.num_class, args.child_layers, args.child_nodes,
                    args.child_channels, args.child_keep_prob,
                    args.child_drop_path_keep_prob, args.child_use_aux_head,
                    args.steps, arch)
            else:
                tmp_model = NASNetworkImageNet(
                    args, args.num_class, args.child_layers, args.child_nodes,
                    args.child_channels, args.child_keep_prob,
                    args.child_drop_path_keep_prob, args.child_use_aux_head,
                    args.steps, arch)
            child_arch_pool_prob.append(
                utils.count_parameters_in_MB(tmp_model))
            del tmp_model

        step = 0
        scheduler = get_scheduler(optimizer, args.dataset)
        for epoch in range(1, args.child_epochs + 1):
            scheduler.step()
            lr = scheduler.get_lr()[0]
            logging.info('epoch %d lr %e', epoch, lr)
            # sample an arch to train
            train_acc, train_obj, step = child_train(train_queue, model,
                                                     optimizer, step,
                                                     child_arch_pool,
                                                     child_arch_pool_prob,
                                                     train_criterion)
            logging.info('train_acc %f', train_acc)

        logging.info("Evaluate seed archs")
        arch_pool += child_arch_pool
        arch_pool_valid_acc = child_valid(valid_queue, model, arch_pool,
                                          eval_criterion)

        arch_pool_valid_acc_sorted_indices = np.argsort(
            arch_pool_valid_acc)[::-1]
        arch_pool = [arch_pool[i] for i in arch_pool_valid_acc_sorted_indices]
        arch_pool_valid_acc = [
            arch_pool_valid_acc[i] for i in arch_pool_valid_acc_sorted_indices
        ]
        with open(os.path.join(args.output_dir, 'arch_pool.{}'.format(i)),
                  'w') as fa:
            with open(
                    os.path.join(args.output_dir,
                                 'arch_pool.perf.{}'.format(i)), 'w') as fp:
                for arch, perf in zip(arch_pool, arch_pool_valid_acc):
                    arch = ' '.join(map(str, arch[0] + arch[1]))
                    fa.write('{}\n'.format(arch))
                    fp.write('{}\n'.format(perf))
        if i == 3:
            break

        # Train Encoder-Predictor-Decoder
        logging.info('Train Encoder-Predictor-Decoder')
        encoder_input = list(
            map(lambda x: utils.parse_arch_to_seq(x[0]) +
                utils.parse_arch_to_seq(x[1]), arch_pool))
        # [[conv, reduc]]
        min_val = min(arch_pool_valid_acc)
        max_val = max(arch_pool_valid_acc)
        encoder_target = [(i - min_val) / (max_val - min_val)
                          for i in arch_pool_valid_acc]

        if args.controller_expand:
            dataset = list(zip(encoder_input, encoder_target))
            n = len(dataset)
            ratio = 0.9
            split = int(n * ratio)
            np.random.shuffle(dataset)
            encoder_input, encoder_target = list(zip(*dataset))
            train_encoder_input = list(encoder_input[:split])
            train_encoder_target = list(encoder_target[:split])
            valid_encoder_input = list(encoder_input[split:])
            valid_encoder_target = list(encoder_target[split:])
            for _ in range(args.controller_expand - 1):
                for src, tgt in zip(encoder_input[:split],
                                    encoder_target[:split]):
                    a = np.random.randint(0, args.child_nodes)
                    b = np.random.randint(0, args.child_nodes)
                    src = src[:4 * a] + src[4 * a + 2:4 * a + 4] + \
                            src[4 * a:4 * a + 2] + src[4 * (a + 1):20 + 4 * b] + \
                            src[20 + 4 * b + 2:20 + 4 * b + 4] + src[20 + 4 * b:20 + 4 * b + 2] + \
                            src[20 + 4 * (b + 1):]
                    train_encoder_input.append(src)
                    train_encoder_target.append(tgt)
        else:
            train_encoder_input = encoder_input
            train_encoder_target = encoder_target
            valid_encoder_input = encoder_input
            valid_encoder_target = encoder_target
        logging.info('Train data: {}\tValid data: {}'.format(
            len(train_encoder_input), len(valid_encoder_input)))

        nao_train_dataset = utils.NAODataset(
            train_encoder_input,
            train_encoder_target,
            True,
            swap=True if args.controller_expand is None else False)
        nao_valid_dataset = utils.NAODataset(valid_encoder_input,
                                             valid_encoder_target, False)
        nao_train_queue = torch.utils.data.DataLoader(
            nao_train_dataset,
            batch_size=args.controller_batch_size,
            shuffle=True,
            pin_memory=True)
        nao_valid_queue = torch.utils.data.DataLoader(
            nao_valid_dataset,
            batch_size=args.controller_batch_size,
            shuffle=False,
            pin_memory=True)
        nao_optimizer = torch.optim.Adam(nao.parameters(),
                                         lr=args.controller_lr,
                                         weight_decay=args.controller_l2_reg)
        for nao_epoch in range(1, args.controller_epochs + 1):
            nao_loss, nao_mse, nao_ce = nao_train(nao_train_queue, nao,
                                                  nao_optimizer)
            logging.info("epoch %04d train loss %.6f mse %.6f ce %.6f",
                         nao_epoch, nao_loss, nao_mse, nao_ce)
            if nao_epoch % 100 == 0:
                pa, hs = nao_valid(nao_valid_queue, nao)
                logging.info("Evaluation on valid data")
                logging.info(
                    'epoch %04d pairwise accuracy %.6f hamming distance %.6f',
                    nao_epoch, pa, hs)

        # Generate new archs
        new_archs = []
        max_step_size = 50
        predict_step_size = 0
        top100_archs = list(
            map(lambda x: utils.parse_arch_to_seq(x[0]) +
                utils.parse_arch_to_seq(x[1]), arch_pool[:100]))
        nao_infer_dataset = utils.NAODataset(top100_archs, None, False)
        nao_infer_queue = torch.utils.data.DataLoader(
            nao_infer_dataset,
            batch_size=len(nao_infer_dataset),
            shuffle=False,
            pin_memory=True)
        while len(new_archs) < args.controller_new_arch:
            predict_step_size += 1
            logging.info('Generate new architectures with step size %d',
                         predict_step_size)
            new_arch = nao_infer(nao_infer_queue,
                                 nao,
                                 predict_step_size,
                                 direction='+')
            for arch in new_arch:
                if arch not in encoder_input and arch not in new_archs:
                    new_archs.append(arch)
                if len(new_archs) >= args.controller_new_arch:
                    break
            logging.info('%d new archs generated now', len(new_archs))
            if predict_step_size > max_step_size:
                break

        child_arch_pool = list(
            map(lambda x: utils.parse_seq_to_arch(x),
                new_archs))  # [[[conv],[reduc]]]
        logging.info("Generate %d new archs", len(child_arch_pool))

    logging.info('Finish Searching')
    logging.info('Reranking top 5 architectures')
    # reranking top 5
    top_archs = arch_pool[:5]
    if args.dataset == 'cifar10':
        top_archs_perf = train_and_evaluate_top_on_cifar10(
            top_archs, train_queue, valid_queue)
    elif args.dataset == 'cifar100':
        top_archs_perf = train_and_evaluate_top_on_cifar100(
            top_archs, train_queue, valid_queue)
    else:
        top_archs_perf = train_and_evaluate_top_on_imagenet(
            top_archs, train_queue, valid_queue)
    top_archs_sorted_indices = np.argsort(top_archs_perf)[::-1]
    top_archs = [top_archs[i] for i in top_archs_sorted_indices]
    top_archs_perf = [top_archs_perf[i] for i in top_archs_sorted_indices]
    with open(os.path.join(args.output_dir, 'arch_pool.final'), 'w') as fa:
        with open(os.path.join(args.output_dir, 'arch_pool.perf.final'),
                  'w') as fp:
            for arch, perf in zip(top_archs, top_archs_perf):
                arch = ' '.join(map(str, arch[0] + arch[1]))
                fa.write('{}\n'.format(arch))
                fp.write('{}\n'.format(perf))
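
A worked instance of the step-count arithmetic used above for args.steps, with an assumed child batch size of 128 and 150 epochs (both values are made up for illustration):

import numpy as np

steps = int(np.ceil(45000 / 128)) * 150   # 352 batches per epoch
assert steps == 52800
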
Example #9
def train():
    child_params = get_child_model_params()
    controller_params = get_controller_params()
    branch_length = controller_params['source_length'] // 2 // 5 // 2
    eval_every_epochs = child_params['eval_every_epochs']
    child_epoch = 0
    while True:
        # Train child model
        if child_params['arch_pool'] is None:
            arch_pool = utils.generate_arch(controller_params['num_seed_arch'],
                                            child_params['num_cells'],
                                            5)  #[[[conv],[reduc]]]
            child_params['arch_pool'] = arch_pool
            child_params['arch_pool_prob'] = None
        else:
            if child_params['sample_policy'] == 'uniform':
                child_params['arch_pool_prob'] = None
            elif child_params['sample_policy'] == 'params':
                child_params['arch_pool_prob'] = calculate_params(
                    child_params['arch_pool'])
            elif child_params['sample_policy'] == 'valid_performance':
                child_params['arch_pool_prob'] = child_valid(child_params)
            elif child_params['sample_policy'] == 'predicted_performance':
                encoder_input = list(map(lambda x: utils.parse_arch_to_seq(x[0], branch_length) + \
                                                   utils.parse_arch_to_seq(x[1], branch_length), child_params['arch_pool']))
                predicted_error_rate = controller.test(controller_params,
                                                       encoder_input)
                child_params['arch_pool_prob'] = [
                    1 - i[0] for i in predicted_error_rate
                ]
            else:
                raise ValueError(
                    'Child model arch pool sample policy is not provided!')

        if isinstance(eval_every_epochs, int):
            child_params['eval_every_epochs'] = eval_every_epochs
        else:
            for index, e in enumerate(eval_every_epochs):
                if child_epoch < e:
                    child_params['eval_every_epochs'] = e
                    break

        child_epoch = child_train(child_params)

        # Evaluate seed archs
        valid_accuracy_list = child_valid(child_params)

        # Output archs and evaluated error rate
        old_archs = child_params['arch_pool']
        old_archs_perf = [1 - i for i in valid_accuracy_list]

        old_archs_sorted_indices = np.argsort(old_archs_perf)
        old_archs = np.array(old_archs)[old_archs_sorted_indices].tolist()
        old_archs_perf = np.array(
            old_archs_perf)[old_archs_sorted_indices].tolist()
        with open(
                os.path.join(child_params['model_dir'],
                             'arch_pool.{}'.format(child_epoch)), 'w') as fa:
            with open(
                    os.path.join(child_params['model_dir'],
                                 'arch_pool.perf.{}'.format(child_epoch)),
                    'w') as fp:
                with open(os.path.join(child_params['model_dir'], 'arch_pool'),
                          'w') as fa_latest:
                    with open(
                            os.path.join(child_params['model_dir'],
                                         'arch_pool.perf'), 'w') as fp_latest:
                        for arch, perf in zip(old_archs, old_archs_perf):
                            arch = ' '.join(map(str, arch[0] + arch[1]))
                            fa.write('{}\n'.format(arch))
                            fa_latest.write('{}\n'.format(arch))
                            fp.write('{}\n'.format(perf))
                            fp_latest.write('{}\n'.format(perf))

        if child_epoch >= child_params['num_epochs']:
            break

        # Train Encoder-Predictor-Decoder
        encoder_input = list(map(lambda x : utils.parse_arch_to_seq(x[0], branch_length) + \
                                          utils.parse_arch_to_seq(x[1], branch_length), old_archs))
        #[[conv, reduc]]
        min_val = min(old_archs_perf)
        max_val = max(old_archs_perf)
        encoder_target = [(i - min_val) / (max_val - min_val)
                          for i in old_archs_perf]
        decoder_target = copy.copy(encoder_input)
        controller_params['batches_per_epoch'] = math.ceil(
            len(encoder_input) / controller_params['batch_size'])
        #if clean controller model
        controller.train(controller_params, encoder_input, encoder_target,
                         decoder_target)

        # Generate new archs
        #old_archs = old_archs[:450]
        new_archs = []
        max_step_size = 100
        controller_params['predict_lambda'] = 0
        top100_archs = list(map(lambda x : utils.parse_arch_to_seq(x[0], branch_length) + \
                                          utils.parse_arch_to_seq(x[1], branch_length), old_archs[:100]))
        while len(new_archs) < 500:
            controller_params['predict_lambda'] += 1
            new_arch = controller.predict(controller_params, top100_archs)
            for arch in new_arch:
                if arch not in encoder_input and arch not in new_archs:
                    new_archs.append(arch)
                if len(new_archs) >= 500:
                    break
            tf.logging.info('{} new archs generated now'.format(
                len(new_archs)))
            if controller_params['predict_lambda'] > max_step_size:
                break
                #[[conv, reduc]]
        new_archs = list(
            map(lambda x: utils.parse_seq_to_arch(x, branch_length),
                new_archs))  #[[[conv],[reduc]]]
        num_new_archs = len(new_archs)
        tf.logging.info("Generate {} new archs".format(num_new_archs))
        new_arch_pool = old_archs[:len(old_archs) - (
            num_new_archs + 50)] + new_archs + utils.generate_arch(50, 5, 5)
        tf.logging.info("Totally {} archs now to train".format(
            len(new_arch_pool)))
        child_params['arch_pool'] = new_arch_pool
        with open(os.path.join(child_params['model_dir'], 'arch_pool'),
                  'w') as f:
            for arch in new_arch_pool:
                arch = ' '.join(map(str, arch[0] + arch[1]))
                f.write('{}\n'.format(arch))
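
The new-architecture loop above gradually increases predict_lambda until enough unseen architectures are collected or the step size exceeds max_step_size. A minimal sketch of that pattern, assuming a hypothetical predict_fn(seed_seqs, lam) that returns candidate sequences for a given step size:

def collect_new_archs(seed_seqs, known_seqs, predict_fn, target_count=500, max_lambda=100):
    # Escalate the step size until enough new, unseen sequences are collected.
    new_seqs = []
    lam = 0
    while len(new_seqs) < target_count and lam <= max_lambda:
        lam += 1  # larger lambda = a larger step along the predictor's gradient
        for seq in predict_fn(seed_seqs, lam):
            if seq not in known_seqs and seq not in new_seqs:
                new_seqs.append(seq)
            if len(new_seqs) >= target_count:
                break
    return new_seqs
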
def main(args):
    logging.info('training on {} gpus'.format(torch.cuda.device_count()))
    logging.info('max tokens {} per gpu'.format(args.max_tokens))
    logging.info('max sentences {} per gpu'.format(
        args.max_sentences if args.max_sentences is not None else 'None'))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = True
    cudnn.deterministic = True

    args.child_num_ops = len(operations.OPERATIONS_ENCODER)
    args.controller_vocab_size = 1 + args.child_num_ops
    args.controller_source_length = args.enc_layers + args.dec_layers
    args.controller_encoder_length = args.controller_decoder_length = args.controller_source_length
    tasks.set_ljspeech_hparams(args)
    logging.info("args = {}".format(args))

    if args.arch_pool is not None:
        logging.info('Architecture pool is provided, loading')
        with open(args.arch_pool) as f:
            archs = f.read().splitlines()
            child_arch_pool = [
                list(map(int,
                         arch.strip().split())) for arch in archs
            ]
    else:
        child_arch_pool = None

    task = tasks.LJSpeechTask(args)
    task.setup_task(ws=True)
    logging.info("Model param size = %d", utils.count_parameters(task.model))

    controller = nao_controller.NAO(args)
    logging.info("Encoder-Predictor-Decoder param size = %d",
                 utils.count_parameters(controller))

    if child_arch_pool is None:
        logging.info(
            'Architecture pool is not provided, randomly generating now')
        child_arch_pool = utils.generate_arch(
            args.controller_seed_arch, args.enc_layers + args.dec_layers,
            args.child_num_ops)

    arch_pool = []
    num_updates = 0
    r_c = list(map(float, args.reward.strip().split()))
    for controller_iteration in range(args.controller_iterations + 1):
        logging.info('Iteration %d', controller_iteration + 1)
        arch_pool += child_arch_pool
        for epoch in range(1, args.max_epochs * len(arch_pool) + 1):
            decoder_loss, stop_loss, loss, num_updates = task.train(
                epoch=epoch, num_updates=num_updates, arch_pool=arch_pool)
            lr = task.scheduler.get_lr()
            logging.info(
                'epoch %d updates %d lr %.6f decoder loss %.6f stop loss %.6f loss %.6f',
                epoch, num_updates, lr, decoder_loss, stop_loss, loss)
        frs, pcrs, dfrs, losses = task.valid_for_search(
            None, gta=False, arch_pool=arch_pool, layer_norm_training=True)
        reward = [
            r_c[0] * fr + r_c[1] * pcr + r_c[2] * dfr - r_c[3] * loss
            for fr, pcr, dfr, loss in zip(frs, pcrs, dfrs, losses)
        ]
        arch_pool_valid_perf = reward

        arch_pool_valid_perf_sorted_indices = np.argsort(
            arch_pool_valid_perf)[::-1]
        arch_pool = list(
            map(lambda x: arch_pool[x], arch_pool_valid_perf_sorted_indices))
        arch_pool_valid_perf = list(
            map(lambda x: arch_pool_valid_perf[x],
                arch_pool_valid_perf_sorted_indices))
        os.makedirs(args.output_dir, exist_ok=True)
        with open(
                os.path.join(args.output_dir,
                             'arch_pool.{}'.format(controller_iteration)),
                'w') as fa:
            with open(
                    os.path.join(
                        args.output_dir,
                        'arch_pool.perf.{}'.format(controller_iteration)),
                    'w') as fp:
                for arch, perf in zip(arch_pool, arch_pool_valid_perf):
                    arch = ' '.join(map(str, arch))
                    fa.write('{}\n'.format(arch))
                    fp.write('{}\n'.format(perf))
        if controller_iteration == args.controller_iterations:
            break

        # Train Encoder-Predictor-Decoder
        logging.info('Train Encoder-Predictor-Decoder')
        inputs = list(map(lambda x: utils.parse_arch_to_seq(x), arch_pool))
        min_val = min(arch_pool_valid_perf)
        max_val = max(arch_pool_valid_perf)
        targets = list(
            map(lambda x: (x - min_val) / (max_val - min_val),
                arch_pool_valid_perf))

        # Pre-train NAO
        logging.info('Pre-train EPD')
        controller.build_dataset('train', inputs, targets, True)
        controller.build_queue('train')
        for epoch in range(1, args.controller_pretrain_epochs + 1):
            loss, mse, ce = controller.train_epoch('train')
            logging.info("epoch %04d train loss %.6f mse %.6f ce %.6f", epoch,
                         loss, mse, ce)
        logging.info('Finish pre-training EPD')

        # Generate synthetic data
        logging.info('Generate synthetic data for EPD')
        synthetic_inputs, synthetic_targets = controller.generate_synthetic_data(
            inputs, args.controller_random_arch)
        if args.controller_up_sample_ratio:
            all_inputs = inputs * args.controller_up_sample_ratio + synthetic_inputs
            all_targets = targets * args.controller_up_sample_ratio + synthetic_targets
        else:
            all_inputs = inputs + synthetic_inputs
            all_targets = targets + synthetic_targets
        # Train NAO
        logging.info('Train EPD')
        controller.build_dataset('train', all_inputs, all_targets, True)
        controller.build_queue('train')
        for epoch in range(1, args.controller_epochs + 1):
            loss, mse, ce = controller.train_epoch('train')
            logging.info("epoch %04d train loss %.6f mse %.6f ce %.6f", epoch,
                         loss, mse, ce)
        logging.info('Finish training EPD')

        # Generate new archs
        new_archs = []
        max_step_size = 100
        predict_step_size = 0
        # get top 100 from true data and synthetic data
        topk_indices = np.argsort(all_targets)[::-1][:100]
        topk_archs = list(map(lambda x: all_inputs[x], topk_indices))
        controller.build_dataset('infer', topk_archs, None, False)
        controller.build_queue('infer',
                               batch_size=len(topk_archs),
                               shuffle=False)
        while len(new_archs) < args.controller_new_arch:
            predict_step_size += 0.1
            logging.info('Generate new architectures with step size %.2f',
                         predict_step_size)
            new_arch = controller.infer('infer',
                                        predict_step_size,
                                        direction='+')
            for arch in new_arch:
                if arch not in inputs and arch not in new_archs:
                    new_archs.append(arch)
                if len(new_archs) >= args.controller_new_arch:
                    break
            logging.info('%d new archs generated now', len(new_archs))
            if predict_step_size > max_step_size:
                break

        child_arch_pool = list(
            map(lambda x: utils.parse_seq_to_arch(x), new_archs))
        logging.info("Generate %d new archs", len(child_arch_pool))

    logging.info('Finish Searching')
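
In the search loop above, the score used to rank architectures is a weighted combination of the four validation statistics returned by task.valid_for_search, with the weights read from args.reward as four space-separated floats. A minimal sketch of just that combination (the metric names are kept abstract and the example weight string is an assumption; only the weighted form is taken from the code):

def combine_reward(reward_spec, frs, pcrs, dfrs, losses):
    # reward_spec is a string such as '1 1 1 0.1'; the last weight penalizes loss.
    w = list(map(float, reward_spec.strip().split()))
    return [w[0] * fr + w[1] * pcr + w[2] * dfr - w[3] * loss
            for fr, pcr, dfr, loss in zip(frs, pcrs, dfrs, losses)]

# e.g. combine_reward('1 1 1 0.1', [0.9], [0.8], [0.7], [2.0]) -> roughly [2.2]
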
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = True
    cudnn.deterministic = True

    args.device_count = torch.cuda.device_count() if torch.cuda.is_available(
    ) else 1
    args.lr = args.lr * args.device_count
    args.batch_size = args.batch_size * args.device_count
    args.eval_batch_size = args.eval_batch_size * args.device_count
    args.width_stages = [int(val) for val in args.width_stages.split(',')]
    args.n_cell_stages = [int(val) for val in args.n_cell_stages.split(',')]
    args.stride_stages = [int(val) for val in args.stride_stages.split(',')]
    args.num_class = 1000
    args.num_ops = len(utils.OPERATIONS)
    args.controller_vocab_size = 1 + args.num_ops
    args.controller_source_length = args.layers
    args.controller_encoder_length = args.layers
    args.controller_decoder_length = args.controller_source_length

    logging.info("args = %s", args)

    if args.arch_pool is not None:
        logging.info('Architecture pool is provided, loading')
        with open(args.arch_pool) as f:
            archs = f.read().splitlines()
            archs = list(
                map(lambda x: list(map(int,
                                       x.strip().split())), archs))
            child_arch_pool = archs
    else:
        child_arch_pool = None

    train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler = build_imagenet(
        None, None, valid_num=5000, epoch=-1)

    controller = NAO(
        args.controller_encoder_layers,
        args.controller_mlp_layers,
        args.controller_decoder_layers,
        args.controller_vocab_size,
        args.controller_hidden_size,
        args.controller_mlp_hidden_size,
        args.controller_dropout,
        args.controller_encoder_length,
        args.controller_source_length,
        args.controller_decoder_length,
    )
    controller = controller.cuda()
    logging.info("Encoder-Predictor-Decoder param size = %d",
                 utils.count_parameters(controller))

    if child_arch_pool is None:
        logging.info(
            'Architecture pool is not provided, randomly generating now')
        child_arch_pool = utils.generate_arch(args.controller_seed_arch,
                                              args.layers, args.num_ops)

    arch_pool = []
    arch_pool_valid_acc = []
    for controller_iteration in range(args.controller_iterations + 1):
        logging.info('Iteration %d', controller_iteration + 1)

        child_arch_pool_prob = None
        num_updates = 0
        max_num_updates = args.max_num_updates
        epoch = 1
        while True:
            lr = scheduler.get_lr()[0]
            logging.info('epoch %d lr %e', epoch, lr)
            # sample an arch to train
            train_acc, train_obj, num_updates = child_train(
                train_queue, model, optimizer, num_updates, child_arch_pool,
                child_arch_pool_prob, train_criterion)
            epoch += 1
            scheduler.step()
            if num_updates >= max_num_updates:
                break

        logging.info("Evaluate seed archs")
        arch_pool += child_arch_pool
        arch_pool_valid_acc = child_valid(valid_queue, model, arch_pool,
                                          eval_criterion)

        arch_pool_valid_acc_sorted_indices = np.argsort(
            arch_pool_valid_acc)[::-1]
        arch_pool = list(
            map(lambda x: arch_pool[x], arch_pool_valid_acc_sorted_indices))
        arch_pool_valid_acc = list(
            map(lambda x: arch_pool_valid_acc[x],
                arch_pool_valid_acc_sorted_indices))
        with open(
                os.path.join(args.output_dir,
                             'arch_pool.{}'.format(controller_iteration)),
                'w') as fa:
            with open(
                    os.path.join(
                        args.output_dir,
                        'arch_pool.perf.{}'.format(controller_iteration)),
                    'w') as fp:
                for arch, perf in zip(arch_pool, arch_pool_valid_acc):
                    arch = ' '.join(map(str, arch))
                    fa.write('{}\n'.format(arch))
                    fp.write('{}\n'.format(perf))
        if controller_iteration == args.controller_iterations:
            break

        # Train Encoder-Predictor-Decoder
        logging.info('Train Encoder-Predictor-Decoder')
        inputs = arch_pool
        min_val = min(arch_pool_valid_acc)
        max_val = max(arch_pool_valid_acc)
        targets = list(
            map(lambda x: (x - min_val) / (max_val - min_val),
                arch_pool_valid_acc))

        # Pre-train
        logging.info('Pre-train EPD')
        train_controller(controller, inputs, targets,
                         args.controller_pretrain_epochs)
        logging.info('Finish pre-training EPD')
        # Generate synthetic data
        logging.info('Generate synthetic data for EPD')
        synthetic_inputs, synthetic_targets = generate_synthetic_controller_data(
            controller, inputs, args.controller_random_arch)
        if args.controller_up_sample_ratio:
            all_inputs = inputs * args.controller_up_sample_ratio + synthetic_inputs
            all_targets = targets * args.controller_up_sample_ratio + synthetic_targets
        else:
            all_inputs = inputs + synthetic_inputs
            all_targets = targets + synthetic_targets
        # Train
        logging.info('Train EPD')
        train_controller(controller, all_inputs, all_targets,
                         args.controller_epochs)
        logging.info('Finish training EPD')

        # Generate new archs
        new_archs = []
        max_step_size = 100
        predict_step_size = 0.0
        topk_indices = np.argsort(all_targets)[::-1][:100]
        topk_archs = list(map(lambda x: all_inputs[x], topk_indices))
        infer_dataset = utils.ControllerDataset(topk_archs, None, False)
        infer_queue = torch.utils.data.DataLoader(
            infer_dataset,
            batch_size=len(infer_dataset),
            shuffle=False,
            pin_memory=True)
        while len(new_archs) < args.controller_new_arch:
            predict_step_size += 0.1
            logging.info('Generate new architectures with step size %.2f',
                         predict_step_size)
            new_arch = controller_infer(infer_queue,
                                        controller,
                                        predict_step_size,
                                        direction='+')
            for arch in new_arch:
                if arch not in inputs and arch not in new_archs:
                    new_archs.append(arch)
                if len(new_archs) >= args.controller_new_arch:
                    break
            logging.info('%d new archs generated now', len(new_archs))
            if predict_step_size > max_step_size:
                break

        child_arch_pool = new_archs
        logging.info("Generate %d new archs", len(child_arch_pool))

    logging.info('Finish Searching')
    logging.info('Reranking top 5 architectures')
    # reranking top 5
    top_archs = arch_pool[:5]
    top_archs_perf = train_and_evaluate(top_archs, train_queue, valid_queue)
    top_archs_sorted_indices = np.argsort(top_archs_perf)[::-1]
    top_archs = [top_archs[i] for i in top_archs_sorted_indices]
    top_archs_perf = [top_archs_perf[i] for i in top_archs_sorted_indices]
    with open(os.path.join(args.output_dir, 'arch_pool.final'), 'w') as fa:
        with open(os.path.join(args.output_dir, 'arch_pool.perf.final'),
                  'w') as fp:
            for arch, perf in zip(top_archs, top_archs_perf):
                arch = ' '.join(map(str, arch))
                fa.write('{}\n'.format(arch))
                fp.write('{}\n'.format(perf))
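
Before the final EPD training pass above, the real (input, target) pairs are repeated controller_up_sample_ratio times and the synthetic pairs appended, presumably so the architectures with measured accuracy are not drowned out by the synthetic ones. A minimal sketch of that mixing step, under that assumption:

def mix_training_data(inputs, targets, synthetic_inputs, synthetic_targets, up_sample_ratio=None):
    # Repeat the real pairs up_sample_ratio times, then append the synthetic pairs.
    if up_sample_ratio:
        return (inputs * up_sample_ratio + synthetic_inputs,
                targets * up_sample_ratio + synthetic_targets)
    return inputs + synthetic_inputs, targets + synthetic_targets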
Exemple #12
0
def main():
    random.seed(args.seed)
    np.random.seed(args.seed)

    args.steps = int(np.ceil(45000 / args.child_batch_size)) * args.child_epochs

    logging.info("args = %s", args)

    if args.child_arch_pool is not None:
        logging.info('Architecture pool is provided, loading')
        with open(args.child_arch_pool) as f:
            archs = f.read().splitlines()
            archs = list(map(utils.build_arch, archs))
            child_arch_pool = archs
    elif os.path.exists(os.path.join(args.output_dir, 'arch_pool')):
        logging.info('Architecture pool is found, loading')
        with open(os.path.join(args.output_dir, 'arch_pool')) as f:
            archs = f.read().splitlines()
            archs = list(map(utils.build_arch, archs))
            child_arch_pool = archs
    else:
        child_arch_pool = None

    child_eval_epochs = eval(args.child_eval_epochs)
    build_fn = get_builder(args.dataset)
    train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler = build_fn(ratio=0.9, epoch=-1)

    nao = NAO(
        args.controller_encoder_layers,
        args.controller_encoder_vocab_size,
        args.controller_encoder_hidden_size,
        args.controller_encoder_dropout,
        args.controller_encoder_length,
        args.controller_source_length,
        args.controller_encoder_emb_size,
        args.controller_mlp_layers,
        args.controller_mlp_hidden_size,
        args.controller_mlp_dropout,
        args.controller_decoder_layers,
        args.controller_decoder_vocab_size,
        args.controller_decoder_hidden_size,
        args.controller_decoder_dropout,
        args.controller_decoder_length,
    )
    nao = nao.cuda()
    logging.info("Encoder-Predictor-Decoder param size = %fMB", utils.count_parameters_in_MB(nao))

    # Train child model
    if child_arch_pool is None:
        logging.info('Architecture pool is not provided, randomly generating now')
        child_arch_pool = utils.generate_arch(args.controller_seed_arch, args.child_nodes, 5)  # [[[conv],[reduc]]]
    child_arch_pool_prob = None

    eval_points = utils.generate_eval_points(child_eval_epochs, 0, args.child_epochs)
    step = 0
    for epoch in range(1, args.child_epochs + 1):
        scheduler.step()
        lr = scheduler.get_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)
        # sample an arch to train
        train_acc, train_obj, step = child_train(train_queue, model, optimizer, step, child_arch_pool, child_arch_pool_prob, train_criterion)
        logging.info('train_acc %f', train_acc)

        if epoch not in eval_points:
            continue
        # Evaluate seed archs
        valid_accuracy_list = child_valid(valid_queue, model, child_arch_pool, eval_criterion)

        # Output archs and evaluated error rate
        old_archs = child_arch_pool
        old_archs_perf = valid_accuracy_list

        old_archs_sorted_indices = np.argsort(old_archs_perf)[::-1]
        old_archs = [old_archs[i] for i in old_archs_sorted_indices]
        old_archs_perf = [old_archs_perf[i] for i in old_archs_sorted_indices]
        with open(os.path.join(args.output_dir, 'arch_pool.{}'.format(epoch)), 'w') as fa:
            with open(os.path.join(args.output_dir, 'arch_pool.perf.{}'.format(epoch)), 'w') as fp:
                with open(os.path.join(args.output_dir, 'arch_pool'), 'w') as fa_latest:
                    with open(os.path.join(args.output_dir, 'arch_pool.perf'), 'w') as fp_latest:
                        for arch, perf in zip(old_archs, old_archs_perf):
                            arch = ' '.join(map(str, arch[0] + arch[1]))
                            fa.write('{}\n'.format(arch))
                            fa_latest.write('{}\n'.format(arch))
                            fp.write('{}\n'.format(perf))
                            fp_latest.write('{}\n'.format(perf))

        if epoch == args.child_epochs:
            break

        # Train Encoder-Predictor-Decoder
        logging.info('Training Encoder-Predictor-Decoder')
        encoder_input = list(map(lambda x: utils.parse_arch_to_seq(x[0], 2) + utils.parse_arch_to_seq(x[1], 2), old_archs))
        # [[conv, reduc]]
        min_val = min(old_archs_perf)
        max_val = max(old_archs_perf)
        encoder_target = [(i - min_val) / (max_val - min_val) for i in old_archs_perf]

        if args.controller_expand:
            dataset = list(zip(encoder_input, encoder_target))
            n = len(dataset)
            ratio = 0.9
            split = int(n * ratio)
            np.random.shuffle(dataset)
            encoder_input, encoder_target = list(zip(*dataset))
            train_encoder_input = list(encoder_input[:split])
            train_encoder_target = list(encoder_target[:split])
            valid_encoder_input = list(encoder_input[split:])
            valid_encoder_target = list(encoder_target[split:])
            for _ in range(args.controller_expand - 1):
                for src, tgt in zip(encoder_input[:split], encoder_target[:split]):
                    a = np.random.randint(0, args.child_nodes)
                    b = np.random.randint(0, args.child_nodes)
                    src = src[:4 * a] + src[4 * a + 2:4 * a + 4] + \
                          src[4 * a:4 * a + 2] + src[4 * (a + 1):20 + 4 * b] + \
                          src[20 + 4 * b + 2:20 + 4 * b + 4] + src[20 + 4 * b:20 + 4 * b + 2] + \
                          src[20 + 4 * (b + 1):]
                    train_encoder_input.append(src)
                    train_encoder_target.append(tgt)
        else:
            train_encoder_input = encoder_input
            train_encoder_target = encoder_target
            valid_encoder_input = encoder_input
            valid_encoder_target = encoder_target
        logging.info('Train data: {}\tValid data: {}'.format(len(train_encoder_input), len(valid_encoder_input)))

        nao_train_dataset = utils.NAODataset(train_encoder_input, train_encoder_target, True, swap=True if args.controller_expand is None else False)
        nao_valid_dataset = utils.NAODataset(valid_encoder_input, valid_encoder_target, False)
        nao_train_queue = torch.utils.data.DataLoader(
            nao_train_dataset, batch_size=args.controller_batch_size, shuffle=True, pin_memory=True)
        nao_valid_queue = torch.utils.data.DataLoader(
            nao_valid_dataset, batch_size=args.controller_batch_size, shuffle=False, pin_memory=True)
        nao_optimizer = torch.optim.Adam(nao.parameters(), lr=args.controller_lr, weight_decay=args.controller_l2_reg)
        for nao_epoch in range(1, args.controller_epochs + 1):
            nao_loss, nao_mse, nao_ce = nao_train(nao_train_queue, nao, nao_optimizer)
            logging.info("epoch %04d train loss %.6f mse %.6f ce %.6f", nao_epoch, nao_loss, nao_mse, nao_ce)
            if nao_epoch % 100 == 0:
                pa, hs = nao_valid(nao_valid_queue, nao)
                logging.info("Evaluation on valid data")
                logging.info('epoch %04d pairwise accuracy %.6f hamming distance %.6f', epoch, pa, hs)
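
The controller_expand branch above augments the EPD training data through index arithmetic on the flat sequence: each cell appears to occupy 20 tokens (5 nodes of 4 tokens, i.e. two 2-token branches per node), and the loop swaps the two branches of one random node in the convolution cell and one in the reduction cell, reusing the original target since branch order should not change the architecture. A minimal sketch of that swap under those assumptions:

NODE_LEN = 4    # two branches of (input index, op) per node, as implied by the offsets above
CELL_LEN = 20   # 5 nodes per cell

def swap_branches(seq, node, offset=0):
    # Swap the two 2-token branches of `node` inside the cell starting at `offset`.
    i = offset + NODE_LEN * node
    return seq[:i] + seq[i + 2:i + 4] + seq[i:i + 2] + seq[i + 4:]

def augment(seq, conv_node, reduc_node):
    seq = swap_branches(seq, conv_node, offset=0)          # convolution cell
    seq = swap_branches(seq, reduc_node, offset=CELL_LEN)  # reduction cell
    return seq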
Exemple #13
0
def main():
    random.seed(args.seed)
    np.random.seed(args.seed)
    
    logging.info("Args = %s", args)

    nasbench = api.NASBench(os.path.join(args.data, 'nasbench_only108.tfrecord'))

    arch_pool, seq_pool, valid_accs = utils.generate_arch(args.n, nasbench, need_perf=True)
    feature_name = utils.get_feature_name()
    params = {
        'boosting_type': 'gbdt',
        'objective': 'regression',
        'metric': {'l2'},
        'num_leaves': args.leaves,
        'learning_rate': args.lr,
        'feature_fraction': 0.9,
        'bagging_fraction': 0.8,
        'bagging_freq': 5,
        'verbose': 0
    }

    sorted_indices = np.argsort(valid_accs)[::-1]
    arch_pool = [arch_pool[i] for i in sorted_indices]
    seq_pool = [seq_pool[i] for i in sorted_indices]
    valid_accs = [valid_accs[i] for i in sorted_indices]
    with open(os.path.join(args.output_dir, 'arch_pool.{}'.format(0)), 'w') as f:
        for arch, seq, valid_acc in zip(arch_pool, seq_pool, valid_accs):
            f.write('{}\t{}\t{}\t{}\n'.format(arch.matrix, arch.ops, seq, valid_acc))
    mean_val = 0.908192
    std_val = 0.023961
    for i in range(args.iterations):
        logging.info('Iteration {}'.format(i+1))
        normed_valid_accs = [(acc - mean_val) / std_val for acc in valid_accs]
        train_x = np.array(seq_pool)
        train_y = np.array(normed_valid_accs)
        
        # Train GBDT-NAS
        logging.info('Train GBDT-NAS')
        lgb_train = lgb.Dataset(train_x, train_y)

        gbm = lgb.train(params, lgb_train, feature_name=feature_name, num_boost_round=args.num_boost_round)
        gbm.save_model(os.path.join(args.output_dir, 'model.txt'))
    
        all_arch, all_seq = utils.generate_arch(None, nasbench)
        logging.info('Totally {} archs from the search space'.format(len(all_seq)))
        all_pred = gbm.predict(np.array(all_seq), num_iteration=gbm.best_iteration)
        sorted_indices = np.argsort(all_pred)[::-1]
        all_arch = [all_arch[i] for i in sorted_indices]
        all_seq = [all_seq[i] for i in sorted_indices]
        new_arch, new_seq, new_valid_acc = [], [], []
        for arch, seq in zip(all_arch, all_seq):
            if seq in seq_pool:
                continue
            new_arch.append(arch)
            new_seq.append(seq)
            new_valid_acc.append(nasbench.query(arch)['validation_accuracy'])
            if len(new_arch) >= args.k:
                break
        arch_pool += new_arch
        seq_pool += new_seq
        valid_accs += new_valid_acc
        
        sorted_indices = np.argsort(valid_accs)[::-1]
        arch_pool = [arch_pool[i] for i in sorted_indices]
        seq_pool = [seq_pool[i] for i in sorted_indices]
        valid_accs = [valid_accs[i] for i in sorted_indices]
        with open(os.path.join(args.output_dir, 'arch_pool.{}'.format(i+1)), 'w') as f:
            for arch, seq, va in zip(arch_pool, seq_pool, valid_accs):
                f.write('{}\t{}\t{}\t{}\n'.format(arch.matrix, arch.ops, seq, va))

    logging.info('Finish Searching\n')    
    with open(os.path.join(args.output_dir, 'arch_pool.final'), 'w') as f:
        for i in range(10):
            arch, seq, valid_acc = arch_pool[i], seq_pool[i], valid_accs[i]
            fs, cs = nasbench.get_metrics_from_spec(arch)
            test_acc = np.mean([cs[108][run]['final_test_accuracy'] for run in range(3)])
            f.write('{}\t{}\t{}\t{}\t{}\n'.format(arch.matrix, arch.ops, seq, valid_acc, test_acc))
            print('{}\t{}\tvalid acc: {}\tmean test acc: {}\n'.format(arch.matrix, arch.ops, valid_acc, test_acc))
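
The GBDT above is trained on validation accuracies normalized with the fixed statistics mean_val = 0.908192 and std_val = 0.023961, and its raw predictions are only used for ranking. A minimal sketch (not part of the code above) of mapping a predicted score back to an accuracy estimate using those same statistics:

MEAN_VAL = 0.908192
STD_VAL = 0.023961

def denormalize(pred):
    # Invert the normalization applied to the training targets.
    return pred * STD_VAL + MEAN_VAL

# e.g. a predicted score of 1.0 corresponds to roughly 0.932 validation accuracy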