Example #1
def build_cifar100(model_state_dict=None, optimizer_state_dict=None, **kwargs):
    epoch = kwargs.pop('epoch')

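    # Note: the CIFAR-10 data transforms are reused here for CIFAR-100.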
    train_transform, valid_transform = utils._data_transforms_cifar10(
        args.cutout_size)
    train_data = dset.CIFAR100(root=args.data,
                               train=True,
                               download=True,
                               transform=train_transform)
    valid_data = dset.CIFAR100(root=args.data,
                               train=False,
                               download=True,
                               transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=16)
    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=args.eval_batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=16)

    model = NASNetworkCIFAR(args, 100, args.layers, args.nodes, args.channels,
                            args.keep_prob, args.drop_path_keep_prob,
                            args.use_aux_head, args.steps, args.arch)
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
    logging.info("multi adds = %fM", model.multi_adds / 1000000)
    if model_state_dict is not None:
        model.load_state_dict(model_state_dict)

    if torch.cuda.device_count() > 1:
        logging.info("Use %d %s", torch.cuda.device_count(), "GPUs !")
        model = nn.DataParallel(model)
    model = model.cuda()

    train_criterion = nn.CrossEntropyLoss().cuda()
    eval_criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr_max,
        momentum=0.9,
        weight_decay=args.l2_reg,
    )

    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), args.lr_min, epoch)
    return train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler
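A minimal usage sketch for the builder above (hypothetical call site, not part of the source; it mirrors the build_fn(ratio=0.9, epoch=-1) pattern used in the search scripts below and assumes the same global args):
# Hypothetical call site for build_cifar100 (illustration only).
train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler = \
    build_cifar100(None, None, epoch=-1)  # fresh model/optimizer; cosine schedule starts at epoch -1
for epoch in range(1, args.epochs + 1):
    scheduler.step()
    # ... run one training epoch with train_queue, train_criterion and optimizer,
    # then evaluate on valid_queue with eval_criterion ...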
Example #2
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = False
    cudnn.deterministic = True

    args.steps = int(np.ceil(
        45000 / args.child_batch_size)) * args.child_epochs

    logging.info("args = %s", args)

    if args.child_arch_pool is not None:
        logging.info('Architecture pool is provided, loading')
        with open(args.child_arch_pool) as f:
            archs = f.read().splitlines()
            archs = list(map(utils.build_dag, archs))
            child_arch_pool = archs
    elif os.path.exists(os.path.join(args.output_dir, 'arch_pool')):
        logging.info('Architecture pool found, loading')
        with open(os.path.join(args.output_dir, 'arch_pool')) as f:
            archs = f.read().splitlines()
            archs = list(map(utils.build_dag, archs))
            child_arch_pool = archs
    else:
        child_arch_pool = None

    child_eval_epochs = eval(args.child_eval_epochs)
    build_fn = get_builder(args.dataset)
    train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler = build_fn(
        ratio=0.9, epoch=-1)

    nao = NAO(
        args.controller_encoder_layers,
        args.controller_encoder_vocab_size,
        args.controller_encoder_hidden_size,
        args.controller_encoder_dropout,
        args.controller_encoder_length,
        args.controller_source_length,
        args.controller_encoder_emb_size,
        args.controller_mlp_layers,
        args.controller_mlp_hidden_size,
        args.controller_mlp_dropout,
        args.controller_decoder_layers,
        args.controller_decoder_vocab_size,
        args.controller_decoder_hidden_size,
        args.controller_decoder_dropout,
        args.controller_decoder_length,
    )
    nao = nao.cuda()
    logging.info("Encoder-Predictor-Decoder param size = %fMB",
                 utils.count_parameters_in_MB(nao))

    # Train child model
    if child_arch_pool is None:
        logging.info(
            'Architecture pool is not provided, randomly generating now')
        child_arch_pool = utils.generate_arch(args.controller_seed_arch,
                                              args.child_nodes,
                                              5)  # [[[conv],[reduc]]]
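        # The literal 5 is the number of candidate operations per node;
        # Example #5 below passes args.child_num_ops in the same position.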
    if args.child_sample_policy == 'params':
        child_arch_pool_prob = []
        for arch in child_arch_pool:
            if args.dataset == 'cifar10':
                tmp_model = NASNetworkCIFAR(
                    args, 10, args.child_layers, args.child_nodes,
                    args.child_channels, args.child_keep_prob,
                    args.child_drop_path_keep_prob, args.child_use_aux_head,
                    args.steps, arch)
            elif args.dataset == 'cifar100':
                tmp_model = NASNetworkCIFAR(
                    args, 100, args.child_layers, args.child_nodes,
                    args.child_channels, args.child_keep_prob,
                    args.child_drop_path_keep_prob, args.child_use_aux_head,
                    args.steps, arch)
            else:
                tmp_model = NASNetworkImageNet(
                    args, 1000, args.child_layers, args.child_nodes,
                    args.child_channels, args.child_keep_prob,
                    args.child_drop_path_keep_prob, args.child_use_aux_head,
                    args.steps, arch)
            child_arch_pool_prob.append(
                utils.count_parameters_in_MB(tmp_model))
            del tmp_model
    else:
        child_arch_pool_prob = None

    eval_points = utils.generate_eval_points(child_eval_epochs, 0,
                                             args.child_epochs)
    step = 0
    for epoch in range(1, args.child_epochs + 1):
        scheduler.step()
        lr = scheduler.get_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)
        # sample an arch to train
        train_acc, train_obj, step = child_train(train_queue, model, optimizer,
                                                 step, child_arch_pool,
                                                 child_arch_pool_prob,
                                                 train_criterion)
        logging.info('train_acc %f', train_acc)

        if epoch not in eval_points:
            continue
        # Evaluate seed archs
        valid_accuracy_list = child_valid(valid_queue, model, child_arch_pool,
                                          eval_criterion)

        # Output archs and evaluated error rate
        old_archs = child_arch_pool
        old_archs_perf = valid_accuracy_list

        old_archs_sorted_indices = np.argsort(old_archs_perf)[::-1]
        old_archs = [old_archs[i] for i in old_archs_sorted_indices]
        old_archs_perf = [old_archs_perf[i] for i in old_archs_sorted_indices]
        with open(os.path.join(args.output_dir, 'arch_pool.{}'.format(epoch)),
                  'w') as fa:
            with open(
                    os.path.join(args.output_dir,
                                 'arch_pool.perf.{}'.format(epoch)),
                    'w') as fp:
                with open(os.path.join(args.output_dir, 'arch_pool'),
                          'w') as fa_latest:
                    with open(os.path.join(args.output_dir, 'arch_pool.perf'),
                              'w') as fp_latest:
                        for arch, perf in zip(old_archs, old_archs_perf):
                            arch = ' '.join(map(str, arch[0] + arch[1]))
                            fa.write('{}\n'.format(arch))
                            fa_latest.write('{}\n'.format(arch))
                            fp.write('{}\n'.format(perf))
                            fp_latest.write('{}\n'.format(perf))
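        # Each written line is the conv-cell tokens followed by the
        # reduction-cell tokens, space-separated; utils.build_dag parses this
        # format back into [[conv], [reduc]] when an arch_pool file is reloaded.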

        if epoch == args.child_epochs:
            break

        # Train Encoder-Predictor-Decoder
        logging.info('Training Encoder-Predictor-Decoder')
        encoder_input = list(
            map(lambda x: utils.parse_arch_to_seq(x[0], 2) +
                utils.parse_arch_to_seq(x[1], 2),
                old_archs))
        # [[conv, reduc]]
        min_val = min(old_archs_perf)
        max_val = max(old_archs_perf)
        encoder_target = [(i - min_val) / (max_val - min_val)
                          for i in old_archs_perf]

        if args.controller_expand is not None:
            dataset = list(zip(encoder_input, encoder_target))
            n = len(dataset)
            ratio = 0.9
            split = int(n * ratio)
            np.random.shuffle(dataset)
            encoder_input, encoder_target = list(zip(*dataset))
            train_encoder_input = list(encoder_input[:split])
            train_encoder_target = list(encoder_target[:split])
            valid_encoder_input = list(encoder_input[split:])
            valid_encoder_target = list(encoder_target[split:])
            for _ in range(args.controller_expand - 1):
                for src, tgt in zip(encoder_input[:split],
                                    encoder_target[:split]):
                    a = np.random.randint(0, args.child_nodes)
                    b = np.random.randint(0, args.child_nodes)
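                    # Augmentation (assuming 5-node cells, 4 tokens per node:
                    # two (input, op) pairs): swap the two pairs inside node a
                    # of the conv cell (tokens 0-19) and node b of the reduction
                    # cell (tokens 20-39). The architecture is semantically
                    # unchanged, so the original target tgt is reused.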
                    src = src[:4 * a] + src[4 * a + 2:4 * a + 4] + \
                            src[4 * a:4 * a + 2] + src[4 * (a + 1):20 + 4 * b] + \
                            src[20 + 4 * b + 2:20 + 4 * b + 4] + src[20 + 4 * b:20 + 4 * b + 2] + \
                            src[20 + 4 * (b + 1):]
                    train_encoder_input.append(src)
                    train_encoder_target.append(tgt)
        else:
            train_encoder_input = encoder_input
            train_encoder_target = encoder_target
            valid_encoder_input = encoder_input
            valid_encoder_target = encoder_target
        logging.info('Train data: {}\tValid data: {}'.format(
            len(train_encoder_input), len(valid_encoder_input)))

        nao_train_dataset = utils.NAODataset(
            train_encoder_input,
            train_encoder_target,
            True,
            swap=True if args.controller_expand is None else False)
        nao_valid_dataset = utils.NAODataset(valid_encoder_input,
                                             valid_encoder_target, False)
        nao_train_queue = torch.utils.data.DataLoader(
            nao_train_dataset,
            batch_size=args.controller_batch_size,
            shuffle=True,
            pin_memory=True)
        nao_valid_queue = torch.utils.data.DataLoader(
            nao_valid_dataset,
            batch_size=args.controller_batch_size,
            shuffle=False,
            pin_memory=True)
        nao_optimizer = torch.optim.Adam(nao.parameters(),
                                         lr=args.controller_lr,
                                         weight_decay=args.controller_l2_reg)
        for nao_epoch in range(1, args.controller_epochs + 1):
            nao_loss, nao_mse, nao_ce = nao_train(nao_train_queue, nao,
                                                  nao_optimizer)
            logging.info("epoch %04d train loss %.6f mse %.6f ce %.6f",
                         nao_epoch, nao_loss, nao_mse, nao_ce)
            if nao_epoch % 100 == 0:
                pa, hs = nao_valid(nao_valid_queue, nao)
                logging.info("Evaluation on valid data")
                logging.info(
                    'epoch %04d pairwise accuracy %.6f hamming distance %.6f',
                    nao_epoch, pa, hs)

        # Generate new archs
        new_archs = []
        max_step_size = 50
        predict_step_size = 0
        top100_archs = list(
            map(lambda x: utils.parse_arch_to_seq(x[0], 2) +
                utils.parse_arch_to_seq(x[1], 2),
                old_archs[:100]))
        nao_infer_dataset = utils.NAODataset(top100_archs, None, False)
        nao_infer_queue = torch.utils.data.DataLoader(
            nao_infer_dataset,
            batch_size=len(nao_infer_dataset),
            shuffle=False,
            pin_memory=True)
        while len(new_archs) < args.controller_new_arch:
            predict_step_size += 1
            logging.info('Generate new architectures with step size %d',
                         predict_step_size)
            new_arch = nao_infer(nao_infer_queue,
                                 nao,
                                 predict_step_size,
                                 direction='+')
            for arch in new_arch:
                if arch not in encoder_input and arch not in new_archs:
                    new_archs.append(arch)
                if len(new_archs) >= args.controller_new_arch:
                    break
            logging.info('%d new archs generated now', len(new_archs))
            if predict_step_size > max_step_size:
                break
        # new_archs are still seq-encoded here: [[conv, reduc]]
        new_archs = list(
            map(lambda x: utils.parse_seq_to_arch(x, 2),
                new_archs))  # [[[conv],[reduc]]]
        num_new_archs = len(new_archs)
        logging.info("Generate %d new archs", num_new_archs)
        # replace bottom archs
        if args.controller_replace:
            new_arch_pool = old_archs[:len(old_archs) - (num_new_archs + args.controller_random_arch)] + \
                            new_archs + utils.generate_arch(args.controller_random_arch, 5, 5)
        # discard all archs except top k
        elif args.controller_discard:
            new_arch_pool = old_archs[:100] + new_archs + utils.generate_arch(
                args.controller_random_arch, 5, 5)
        # use all
        else:
            new_arch_pool = old_archs + new_archs + utils.generate_arch(
                args.controller_random_arch, 5, 5)
        logging.info("Totally %d architectures now to train",
                     len(new_arch_pool))

        child_arch_pool = new_arch_pool
        with open(os.path.join(args.output_dir, 'arch_pool'), 'w') as f:
            for arch in new_arch_pool:
                arch = ' '.join(map(str, arch[0] + arch[1]))
                f.write('{}\n'.format(arch))

        if args.child_sample_policy == 'params':
            child_arch_pool_prob = []
            for arch in child_arch_pool:
                if args.dataset == 'cifar10':
                    tmp_model = NASNetworkCIFAR(
                        args, 10, args.child_layers, args.child_nodes,
                        args.child_channels, args.child_keep_prob,
                        args.child_drop_path_keep_prob,
                        args.child_use_aux_head, args.steps, arch)
                elif args.dataset == 'cifar100':
                    tmp_model = NASNetworkCIFAR(
                        args, 100, args.child_layers, args.child_nodes,
                        args.child_channels, args.child_keep_prob,
                        args.child_drop_path_keep_prob,
                        args.child_use_aux_head, args.steps, arch)
                else:
                    tmp_model = NASNetworkImageNet(
                        args, 1000, args.child_layers, args.child_nodes,
                        args.child_channels, args.child_keep_prob,
                        args.child_drop_path_keep_prob,
                        args.child_use_aux_head, args.steps, arch)
                child_arch_pool_prob.append(
                    utils.count_parameters_in_MB(tmp_model))
                del tmp_model
        else:
            child_arch_pool_prob = None
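The child_arch_pool_prob built above (and again in Example #5) holds per-architecture parameter counts in MB. A minimal sketch of how an architecture could be drawn in proportion to those counts; this assumes child_train (not shown here) uses child_arch_pool_prob as sampling weights:
import numpy as np

def sample_arch(arch_pool, arch_pool_prob=None):
    # Uniform choice when no weights are given (child_sample_policy != 'params').
    if arch_pool_prob is None:
        return arch_pool[np.random.randint(len(arch_pool))]
    # Normalize parameter counts into a distribution and draw one index.
    prob = np.asarray(arch_pool_prob, dtype=np.float64)
    prob = prob / prob.sum()
    index = np.random.choice(len(arch_pool), p=prob)
    return arch_pool[index]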
Example #3
def train_cifar10():
    logging.info("Args = %s", args)
    np.random.seed(args.seed)
    tf.random.set_seed(args.seed)

    global_step = tf.Variable(initial_value=0, trainable=False, dtype=tf.int32)
    epoch = tf.Variable(initial_value=0, trainable=False, dtype=tf.int32)
    best_acc_top1 = tf.Variable(initial_value=0.0,
                                trainable=False,
                                dtype=tf.float32)

    ################################################ model setup #######################################################
    train_ds, test_ds = utils.load_cifar10(args.batch_size, args.cutout_size)
    total_steps = int(np.ceil(50000 / args.batch_size)) * args.epochs

    model = NASNetworkCIFAR(classes=10,
                            reduce_distance=args.cells,
                            num_nodes=args.nodes,
                            channels=args.channels,
                            keep_prob=args.keep_prob,
                            drop_path_keep_prob=args.drop_path_keep_prob,
                            use_aux_head=args.use_aux_head,
                            steps=total_steps,
                            arch=args.arch)

    temp_ = tf.random.uniform((64, 32, 32, 3),
                              minval=0,
                              maxval=1,
                              dtype=tf.float32)
    temp_ = model(temp_, step=1, training=True)
    model.summary()
    model_size = utils.count_parameters_in_MB(model)
    print("param size = {} MB".format(model_size))
    logging.info("param size = %fMB", model_size)

    criterion = keras.losses.CategoricalCrossentropy(from_logits=True)
    learning_rate = keras.experimental.CosineDecay(
        initial_learning_rate=args.initial_lr,
        decay_steps=total_steps,
        alpha=0.0001)
    # learning_rate = keras.optimizers.schedules.ExponentialDecay(
    #     initial_learning_rate=args.initial_lr, decay_steps=total_steps, decay_rate=0.99, staircase=False, name=None
    # )
    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)

    ########################################## restore checkpoint ######################################################
    if args.train_from_scratch:
        utils.clean_dir(args.model_dir)

    checkpoint_path = os.path.join(args.model_dir, 'checkpoints')
    ckpt = tf.train.Checkpoint(model=model,
                               optimizer=optimizer,
                               global_step=global_step,
                               epoch=epoch,
                               best_acc_top1=best_acc_top1)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              checkpoint_path,
                                              max_to_keep=3)
    # if a checkpoint exists, restore the latest checkpoint.
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        print('Latest checkpoint restored!!')

    ############################################# training process #####################################################
    acc_train_result = []
    loss_train_result = []
    acc_test_result = []
    loss_test_result = []

    while epoch.numpy() < args.epochs:
        print('epoch {} lr {}'.format(epoch.numpy(),
                                      optimizer._decayed_lr(tf.float32)))

        train_acc, train_loss, step = train(train_ds,
                                            model,
                                            optimizer,
                                            global_step,
                                            criterion,
                                            classes=10)
        test_acc, test_loss = valid(test_ds, model, criterion, classes=10)

        acc_train_result.append(train_acc)
        loss_train_result.append(train_loss)
        acc_test_result.append(test_acc)
        loss_test_result.append(test_loss)

        logging.info('epoch %d lr %e', epoch.numpy(),
                     optimizer._decayed_lr(tf.float32))
        logging.info(acc_train_result)
        logging.info(loss_train_result)
        logging.info(acc_test_result)
        logging.info(loss_test_result)

        is_best = False
        if test_acc > best_acc_top1:
            # Update the checkpointed best-accuracy variable in place so it is
            # saved and restored with the rest of the checkpoint.
            best_acc_top1.assign(test_acc)
            is_best = True
        epoch.assign_add(1)
        if (epoch.numpy() + 1) % 1 == 0:
            ckpt_save_path = ckpt_manager.save()
            print('Saving checkpoint for epoch {} at {}'.format(
                epoch.numpy() + 1, ckpt_save_path))
        if is_best:
            pass

    utils.plot_single_list(acc_train_result,
                           x_label='epochs',
                           y_label='acc',
                           file_name='acc_train')
    utils.plot_single_list(loss_train_result,
                           x_label='epochs',
                           y_label='loss',
                           file_name='loss_train')
    utils.plot_single_list(acc_test_result,
                           x_label='epochs',
                           y_label='acc',
                           file_name='acc_test')
    utils.plot_single_list(loss_test_result,
                           x_label='epochs',
                           y_label='loss',
                           file_name='loss_test')
Example #4
def train_and_evaluate_top_on_cifar100(archs, train_queue, valid_queue):
    res = []
    train_criterion = nn.CrossEntropyLoss().cuda()
    eval_criterion = nn.CrossEntropyLoss().cuda()
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    for i, arch in enumerate(archs):
        objs.reset()
        top1.reset()
        top5.reset()
        logging.info('Train and evaluate the {} arch'.format(i + 1))
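        # The literals below follow the NASNetworkCIFAR signature used above:
        # keep_prob=0.6, drop_path_keep_prob=0.8, use_aux_head=True.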
        model = NASNetworkCIFAR(args, 100, args.child_layers, args.child_nodes,
                                args.child_channels, 0.6, 0.8, True,
                                args.steps, arch)
        model = model.cuda()
        model.train()
        optimizer = torch.optim.SGD(
            model.parameters(),
            args.child_lr_max,
            momentum=0.9,
            weight_decay=args.child_l2_reg,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, 10, args.child_lr_min)
        global_step = 0
        for e in range(10):
            scheduler.step()
            for step, (input, target) in enumerate(train_queue):
                input = input.cuda().requires_grad_()
                target = target.cuda()

                optimizer.zero_grad()
                # sample an arch to train
                logits, aux_logits = model(input, global_step)
                global_step += 1
                loss = train_criterion(logits, target)
                if aux_logits is not None:
                    aux_loss = train_criterion(aux_logits, target)
                    loss += 0.4 * aux_loss
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(),
                                         args.child_grad_bound)
                optimizer.step()

                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                n = input.size(0)
                objs.update(loss.data, n)
                top1.update(prec1.data, n)
                top5.update(prec5.data, n)

                if (step + 1) % 100 == 0:
                    logging.info('Train %3d %03d loss %e top1 %f top5 %f',
                                 e + 1, step + 1, objs.avg, top1.avg, top5.avg)
        objs.reset()
        top1.reset()
        top5.reset()
        with torch.no_grad():
            model.eval()
            for step, (input, target) in enumerate(valid_queue):
                input = input.cuda()
                target = target.cuda()

                logits, _ = model(input)
                loss = eval_criterion(logits, target)

                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                n = input.size(0)
                objs.update(loss.data, n)
                top1.update(prec1.data, n)
                top5.update(prec5.data, n)

                if (step + 1) % 100 == 0:
                    logging.info('valid %03d %e %f %f', step + 1, objs.avg,
                                 top1.avg, top5.avg)
        res.append(top1.avg)
    return res
Example #5
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = True
    cudnn.deterministic = True

    if args.dataset == 'cifar10':
        args.num_class = 10
    elif args.dataset == 'cifar100':
        args.num_class = 100
    else:
        args.num_class = 10

    if args.search_space == 'small':
        OPERATIONS = OPERATIONS_search_small
    elif args.search_space == 'middle':
        OPERATIONS = OPERATIONS_search_middle
    args.child_num_ops = len(OPERATIONS)
    args.controller_encoder_vocab_size = 1 + (args.child_nodes + 2 -
                                              1) + args.child_num_ops
    args.controller_decoder_vocab_size = args.controller_encoder_vocab_size
    args.steps = int(np.ceil(
        45000 / args.child_batch_size)) * args.child_epochs

    logging.info("args = %s", args)

    if args.child_arch_pool is not None:
        logging.info('Architecture pool is provided, loading')
        with open(args.child_arch_pool) as f:
            archs = f.read().splitlines()
            archs = list(map(utils.build_dag, archs))
            child_arch_pool = archs
    elif os.path.exists(os.path.join(args.output_dir, 'arch_pool')):
        logging.info('Architecture pool found, loading')
        with open(os.path.join(args.output_dir, 'arch_pool')) as f:
            archs = f.read().splitlines()
            archs = list(map(utils.build_dag, archs))
            child_arch_pool = archs
    else:
        child_arch_pool = None

    build_fn = get_builder(args.dataset)
    train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler = build_fn(
        ratio=0.9, epoch=-1)

    nao = NAO(
        args.controller_encoder_layers,
        args.controller_encoder_vocab_size,
        args.controller_encoder_hidden_size,
        args.controller_encoder_dropout,
        args.controller_encoder_length,
        args.controller_source_length,
        args.controller_encoder_emb_size,
        args.controller_mlp_layers,
        args.controller_mlp_hidden_size,
        args.controller_mlp_dropout,
        args.controller_decoder_layers,
        args.controller_decoder_vocab_size,
        args.controller_decoder_hidden_size,
        args.controller_decoder_dropout,
        args.controller_decoder_length,
    )
    nao = nao.cuda()
    logging.info("Encoder-Predictor-Decoder param size = %fMB",
                 utils.count_parameters_in_MB(nao))

    if child_arch_pool is None:
        logging.info(
            'Architecture pool is not provided, randomly generating now')
        child_arch_pool = utils.generate_arch(
            args.controller_seed_arch, args.child_nodes,
            args.child_num_ops)  # [[[conv],[reduc]]]
    arch_pool = []
    arch_pool_valid_acc = []
    for i in range(4):
        logging.info('Iteration %d', i)

        child_arch_pool_prob = []
        for arch in child_arch_pool:
            if args.dataset == 'cifar10':
                tmp_model = NASNetworkCIFAR(
                    args, args.num_class, args.child_layers, args.child_nodes,
                    args.child_channels, args.child_keep_prob,
                    args.child_drop_path_keep_prob, args.child_use_aux_head,
                    args.steps, arch)
            elif args.dataset == 'cifar100':
                tmp_model = NASNetworkCIFAR(
                    args, args.num_class, args.child_layers, args.child_nodes,
                    args.child_channels, args.child_keep_prob,
                    args.child_drop_path_keep_prob, args.child_use_aux_head,
                    args.steps, arch)
            else:
                tmp_model = NASNetworkImageNet(
                    args, args.num_class, args.child_layers, args.child_nodes,
                    args.child_channels, args.child_keep_prob,
                    args.child_drop_path_keep_prob, args.child_use_aux_head,
                    args.steps, arch)
            child_arch_pool_prob.append(
                utils.count_parameters_in_MB(tmp_model))
            del tmp_model

        step = 0
        scheduler = get_scheduler(optimizer, args.dataset)
        for epoch in range(1, args.child_epochs + 1):
            scheduler.step()
            lr = scheduler.get_lr()[0]
            logging.info('epoch %d lr %e', epoch, lr)
            # sample an arch to train
            train_acc, train_obj, step = child_train(train_queue, model,
                                                     optimizer, step,
                                                     child_arch_pool,
                                                     child_arch_pool_prob,
                                                     train_criterion)
            logging.info('train_acc %f', train_acc)

        logging.info("Evaluate seed archs")
        arch_pool += child_arch_pool
        arch_pool_valid_acc = child_valid(valid_queue, model, arch_pool,
                                          eval_criterion)

        arch_pool_valid_acc_sorted_indices = np.argsort(
            arch_pool_valid_acc)[::-1]
        arch_pool = [arch_pool[i] for i in arch_pool_valid_acc_sorted_indices]
        arch_pool_valid_acc = [
            arch_pool_valid_acc[i] for i in arch_pool_valid_acc_sorted_indices
        ]
        with open(os.path.join(args.output_dir, 'arch_pool.{}'.format(i)),
                  'w') as fa:
            with open(
                    os.path.join(args.output_dir,
                                 'arch_pool.perf.{}'.format(i)), 'w') as fp:
                for arch, perf in zip(arch_pool, arch_pool_valid_acc):
                    arch = ' '.join(map(str, arch[0] + arch[1]))
                    fa.write('{}\n'.format(arch))
                    fp.write('{}\n'.format(perf))
        if i == 3:
            break

        # Train Encoder-Predictor-Decoder
        logging.info('Train Encoder-Predictor-Decoder')
        encoder_input = list(
            map(lambda x: utils.parse_arch_to_seq(x[0]) +
                utils.parse_arch_to_seq(x[1]),
                arch_pool))
        # [[conv, reduc]]
        min_val = min(arch_pool_valid_acc)
        max_val = max(arch_pool_valid_acc)
        encoder_target = [(i - min_val) / (max_val - min_val)
                          for i in arch_pool_valid_acc]

        if args.controller_expand:
            dataset = list(zip(encoder_input, encoder_target))
            n = len(dataset)
            ratio = 0.9
            split = int(n * ratio)
            np.random.shuffle(dataset)
            encoder_input, encoder_target = list(zip(*dataset))
            train_encoder_input = list(encoder_input[:split])
            train_encoder_target = list(encoder_target[:split])
            valid_encoder_input = list(encoder_input[split:])
            valid_encoder_target = list(encoder_target[split:])
            for _ in range(args.controller_expand - 1):
                for src, tgt in zip(encoder_input[:split],
                                    encoder_target[:split]):
                    a = np.random.randint(0, args.child_nodes)
                    b = np.random.randint(0, args.child_nodes)
                    src = src[:4 * a] + src[4 * a + 2:4 * a + 4] + \
                            src[4 * a:4 * a + 2] + src[4 * (a + 1):20 + 4 * b] + \
                            src[20 + 4 * b + 2:20 + 4 * b + 4] + src[20 + 4 * b:20 + 4 * b + 2] + \
                            src[20 + 4 * (b + 1):]
                    train_encoder_input.append(src)
                    train_encoder_target.append(tgt)
        else:
            train_encoder_input = encoder_input
            train_encoder_target = encoder_target
            valid_encoder_input = encoder_input
            valid_encoder_target = encoder_target
        logging.info('Train data: {}\tValid data: {}'.format(
            len(train_encoder_input), len(valid_encoder_input)))

        nao_train_dataset = utils.NAODataset(
            train_encoder_input,
            train_encoder_target,
            True,
            swap=True if args.controller_expand is None else False)
        nao_valid_dataset = utils.NAODataset(valid_encoder_input,
                                             valid_encoder_target, False)
        nao_train_queue = torch.utils.data.DataLoader(
            nao_train_dataset,
            batch_size=args.controller_batch_size,
            shuffle=True,
            pin_memory=True)
        nao_valid_queue = torch.utils.data.DataLoader(
            nao_valid_dataset,
            batch_size=args.controller_batch_size,
            shuffle=False,
            pin_memory=True)
        nao_optimizer = torch.optim.Adam(nao.parameters(),
                                         lr=args.controller_lr,
                                         weight_decay=args.controller_l2_reg)
        for nao_epoch in range(1, args.controller_epochs + 1):
            nao_loss, nao_mse, nao_ce = nao_train(nao_train_queue, nao,
                                                  nao_optimizer)
            logging.info("epoch %04d train loss %.6f mse %.6f ce %.6f",
                         nao_epoch, nao_loss, nao_mse, nao_ce)
            if nao_epoch % 100 == 0:
                pa, hs = nao_valid(nao_valid_queue, nao)
                logging.info("Evaluation on valid data")
                logging.info(
                    'epoch %04d pairwise accuracy %.6f hamming distance %.6f',
                    nao_epoch, pa, hs)

        # Generate new archs
        new_archs = []
        max_step_size = 50
        predict_step_size = 0
        top100_archs = list(
            map(lambda x: utils.parse_arch_to_seq(x[0]) +
                utils.parse_arch_to_seq(x[1]),
                arch_pool[:100]))
        nao_infer_dataset = utils.NAODataset(top100_archs, None, False)
        nao_infer_queue = torch.utils.data.DataLoader(
            nao_infer_dataset,
            batch_size=len(nao_infer_dataset),
            shuffle=False,
            pin_memory=True)
        while len(new_archs) < args.controller_new_arch:
            predict_step_size += 1
            logging.info('Generate new architectures with step size %d',
                         predict_step_size)
            new_arch = nao_infer(nao_infer_queue,
                                 nao,
                                 predict_step_size,
                                 direction='+')
            for arch in new_arch:
                if arch not in encoder_input and arch not in new_archs:
                    new_archs.append(arch)
                if len(new_archs) >= args.controller_new_arch:
                    break
            logging.info('%d new archs generated now', len(new_archs))
            if predict_step_size > max_step_size:
                break

        child_arch_pool = list(
            map(lambda x: utils.parse_seq_to_arch(x),
                new_archs))  # [[[conv],[reduc]]]
        logging.info("Generate %d new archs", len(child_arch_pool))

    logging.info('Finish Searching')
    logging.info('Reranking top 5 architectures')
    # reranking top 5
    top_archs = arch_pool[:5]
    if args.dataset == 'cifar10':
        top_archs_perf = train_and_evaluate_top_on_cifar10(
            top_archs, train_queue, valid_queue)
    elif args.dataset == 'cifar100':
        top_archs_perf = train_and_evaluate_top_on_cifar100(
            top_archs, train_queue, valid_queue)
    else:
        top_archs_perf = train_and_evaluate_top_on_imagenet(
            top_archs, train_queue, valid_queue)
    top_archs_sorted_indices = np.argsort(top_archs_perf)[::-1]
    top_archs = [top_archs[i] for i in top_archs_sorted_indices]
    top_archs_perf = [top_archs_perf[i] for i in top_archs_sorted_indices]
    with open(os.path.join(args.output_dir, 'arch_pool.final'), 'w') as fa:
        with open(os.path.join(args.output_dir, 'arch_pool.perf.final'),
                  'w') as fp:
            for arch, perf in zip(top_archs, top_archs_perf):
                arch = ' '.join(map(str, arch[0] + arch[1]))
                fa.write('{}\n'.format(arch))
                fp.write('{}\n'.format(perf))