# Example #1
0
def generate_synthetic_controller_data(nasbench,
                                       model,
                                       base_arch=None,
                                       random_arch=0,
                                       direction='+'):
    """Sample random architectures and label them with the model's predictor.

    Args:
        nasbench: Benchmark handle passed to ``utils.generate_arch``.
        model: Controller whose ``encoder(...)`` returns the predicted value
            as its fourth output.
        base_arch: Sequences that must not be sampled (e.g. the current
            training set). ``None`` is treated as "exclude nothing".
        random_arch: Number of distinct sequences to sample. ``0`` or less
            returns two empty lists without touching ``nasbench``/``model``.
        direction: Unused; kept only for call-site compatibility.

    Returns:
        ``(synthetic_input, synthetic_target)``: the sampled sequences and the
        model's predicted value for each, in the same order.
    """
    random_synthetic_input = []
    random_synthetic_target = []
    if random_arch > 0:
        # `seq not in None` would raise TypeError; treat None as empty.
        excluded = base_arch if base_arch is not None else []
        # NOTE(review): keeps drawing until `random_arch` new sequences are
        # found; this never terminates if fewer unique candidates exist.
        while len(random_synthetic_input) < random_arch:
            seq = utils.generate_arch(1, nasbench)[1][0]
            if seq not in random_synthetic_input and seq not in excluded:
                random_synthetic_input.append(seq)

        controller_synthetic_dataset = utils.ControllerDataset(
            random_synthetic_input, None, False)
        # One batch holding every sampled sequence.
        controller_synthetic_queue = torch.utils.data.DataLoader(
            controller_synthetic_dataset,
            batch_size=len(controller_synthetic_dataset),
            shuffle=False,
            pin_memory=True)

        with torch.no_grad():
            model.eval()
            for sample in controller_synthetic_queue:
                encoder_input = sample['encoder_input'].cuda()
                _, _, _, predict_value = model.encoder(encoder_input)
                values = predict_value.data.squeeze()
                # With a single sample squeeze() collapses to a 0-d tensor
                # whose tolist() is a bare float (and `list += float` fails).
                if values.dim() == 0:
                    values = values.unsqueeze(0)
                random_synthetic_target += values.tolist()
        assert len(random_synthetic_input) == len(random_synthetic_target)
    synthetic_input = random_synthetic_input
    synthetic_target = random_synthetic_target
    assert len(synthetic_input) == len(synthetic_target)
    return synthetic_input, synthetic_target
# Example #2
0
def train_controller(model, train_input, train_target, epochs):
    """Fit ``model`` on (architecture sequence, target) pairs.

    Wraps the pairs in ``utils.ControllerDataset``, iterates them with a
    shuffled ``DataLoader`` of ``args.batch_size``, and optimizes with Adam
    (``args.lr``, weight decay ``args.l2_reg``). Each of the ``epochs``
    epochs is one call to ``controller_train``, whose loss/mse/ce results are
    logged.
    """
    logging.info('Train data: {}'.format(len(train_input)))
    train_set = utils.ControllerDataset(train_input, train_target, True)
    loader = torch.utils.data.DataLoader(train_set,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         pin_memory=True)
    adam = torch.optim.Adam(model.parameters(),
                            lr=args.lr,
                            weight_decay=args.l2_reg)
    for completed in range(epochs):
        epoch = completed + 1  # epochs are reported 1-based
        loss, mse, ce = controller_train(loader, model, adam)
        logging.info("epoch %04d train loss %.6f mse %.6f ce %.6f", epoch,
                     loss, mse, ce)
# Example #3
0
def generate_synthetic_controller_data(nasbench,
                                       model,
                                       base_arch=None,
                                       random_arch=0):
    """Sample distinct benchmark architectures and label them with the model.

    Candidate hashes come from ``nasbench.hash_iterator()`` in a random order
    (``np.random.shuffle``). Each is converted with
    ``utils.convert_arch_to_seq`` and kept if it is neither already sampled
    nor in ``base_arch``.

    Args:
        nasbench: Benchmark handle providing ``hash_iterator`` and
            ``get_metrics_from_hash``.
        model: Controller whose ``encoder(...)`` returns the predicted value
            as its fourth output.
        base_arch: Sequences that must not be sampled; ``None`` excludes
            nothing.
        random_arch: Number of sequences to sample. ``0`` or less returns two
            empty lists without touching ``nasbench``/``model``. If the
            candidates run out first, fewer sequences are returned.

    Returns:
        ``(synthetic_input, synthetic_target)`` of equal length: the sampled
        sequences and the model's predicted value for each.
    """
    random_synthetic_input = []
    random_synthetic_target = []
    if random_arch > 0:
        # `seq not in None` would raise TypeError; treat None as empty.
        excluded = base_arch if base_arch is not None else []
        all_keys = list(nasbench.hash_iterator())
        np.random.shuffle(all_keys)
        # Stop once enough sequences are collected or every key was tried
        # (indexing past the end would raise IndexError).
        for key in all_keys:
            if len(random_synthetic_input) >= random_arch:
                break
            fixed_stats, _ = nasbench.get_metrics_from_hash(key)
            seq = utils.convert_arch_to_seq(fixed_stats['module_adjacency'],
                                            fixed_stats['module_operations'])
            if seq not in random_synthetic_input and seq not in excluded:
                random_synthetic_input.append(seq)

        # A DataLoader with batch_size=0 raises, so skip prediction when
        # nothing was sampled.
        if random_synthetic_input:
            nao_synthetic_dataset = utils.ControllerDataset(
                random_synthetic_input, None, False)
            # One batch holding every sampled sequence.
            nao_synthetic_queue = torch.utils.data.DataLoader(
                nao_synthetic_dataset,
                batch_size=len(nao_synthetic_dataset),
                shuffle=False,
                pin_memory=True)

            with torch.no_grad():
                model.eval()
                for sample in nao_synthetic_queue:
                    encoder_input = sample['encoder_input'].cuda()
                    _, _, _, predict_value = model.encoder(encoder_input)
                    values = predict_value.data.squeeze()
                    # With a single sample squeeze() yields a 0-d tensor whose
                    # tolist() is a bare float (and `list += float` fails).
                    if values.dim() == 0:
                        values = values.unsqueeze(0)
                    random_synthetic_target += values.tolist()
        assert len(random_synthetic_input) == len(random_synthetic_target)

    synthetic_input = random_synthetic_input
    synthetic_target = random_synthetic_target
    assert len(synthetic_input) == len(synthetic_target)
    return synthetic_input, synthetic_target
def main(mode, data_file_path):
    """Train or evaluate a ``Graph2Seq`` model on a controller dataset.

    Args:
        mode: ``"train"`` fits the model for ``conf.epochs`` epochs and saves
            its ``state_dict`` to ``conf.model_path``. ``"test"`` loads that
            checkpoint and logs exact-sequence-match accuracy once per epoch.
            Any other value only builds the model and data loader.
        data_file_path: Path given to ``utils.ControllerDataset``
            (e.g. ``"data/train_data.json"``).

    All hyper-parameters come from the module-level ``conf``. This function
    also overwrites ``conf.source_length``/``encoder_length``/
    ``decoder_length``.
    """
    random.seed(conf.seed)
    np.random.seed(conf.seed)
    torch.manual_seed(conf.seed)
    logging.info("conf = %s", conf)

    # Shared sequence length, derived from the graph size n as
    # (n + 2) * (n - 1) // 2.
    conf.source_length = conf.encoder_length = conf.decoder_length = (
        conf.graph_size + 2) * (conf.graph_size - 1) // 2
    epochs = conf.epochs

    model = Graph2Seq(mode=mode, conf=conf)

    # load data
    dataset = utils.ControllerDataset(data_file_path)
    queue = torch.utils.data.DataLoader(dataset,
                                        batch_size=conf.batch_size,
                                        shuffle=True,
                                        pin_memory=True,
                                        collate_fn=utils.collate_fn)

    if mode == "train":

        model.train()
        # Report the number of samples (len(queue) would be the batch count).
        logging.info('Train data: {}'.format(len(dataset)))
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=conf.learning_rate,
                                     weight_decay=conf.l2_reg)

        def train_step(train_queue, optimizer):
            """Run one epoch of teacher-forced training.

            Returns the running averages from two ``utils.AvgrageMeter``s
            (updated with batch size ``n``). Both meters track the same NLL
            loss.
            """
            objs = utils.AvgrageMeter()
            nll = utils.AvgrageMeter()

            for sample in train_queue:
                fw_adjs = sample['fw_adjs']
                bw_adjs = sample['bw_adjs']
                operations = sample['operations']
                num_nodes = sample['num_nodes']
                sequence = sample['sequence']

                optimizer.zero_grad()
                log_prob, _ = model(fw_adjs,
                                    bw_adjs,
                                    operations,
                                    num_nodes,
                                    targets=sequence)
                # Flatten to one row of log-probs per target token.
                loss = F.nll_loss(
                    log_prob.contiguous().view(-1, log_prob.size(-1)),
                    sequence.view(-1))
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               conf.grad_bound)
                optimizer.step()

                n = sequence.size(0)
                # Store plain floats so the meters do not keep tensors alive.
                loss_value = loss.item()
                objs.update(loss_value, n)
                nll.update(loss_value, n)

            return objs.avg, nll.avg

        for epoch in range(1, epochs + 1):
            loss, ce = train_step(queue, optimizer)
            logging.info("epoch %04d train loss %.6f ce %.6f", epoch, loss, ce)

        ## save trainable parameters
        torch.save(model.state_dict(), conf.model_path)

    if mode == "test":

        model.load_state_dict(torch.load(conf.model_path))
        model.eval()

        def test_step(test_queue):
            """Return the fraction of samples whose whole sequence matches.

            Counts are accumulated over every batch; an empty loader gives
            0.0.
            """
            match = 0
            total = 0
            with torch.no_grad():
                for sample in test_queue:
                    fw_adjs = sample['fw_adjs']
                    bw_adjs = sample['bw_adjs']
                    operations = sample['operations']
                    num_nodes = sample['num_nodes']
                    sequence = sample['sequence']

                    _, predicted_value = model(fw_adjs, bw_adjs, operations,
                                               num_nodes)

                    # NOTE(review): assumes predicted_value and sequence are
                    # both (batch, seq_len) token tensors -- confirm against
                    # Graph2Seq. A sample counts only if every position
                    # matches.
                    match += torch.eq(predicted_value,
                                      sequence).all(dim=1).sum().item()
                    total += sequence.size(0)

            return match / total if total else 0.0

        # Report the number of samples (len(queue) would be the batch count).
        logging.info('Test data: {}'.format(len(dataset)))
        for epoch in range(1, epochs + 1):
            accuracy = test_step(queue)
            logging.info("epoch %04d accuracy %.6f", epoch, accuracy)
# Example #5
0
def main():
    """Run the NAO architecture-search loop on NAS-Bench-101.

    Per iteration:
    - score any unscored child architectures with ``nasbench.query``;
    - merge them into the pool, sorted by validation accuracy (best first);
    - write the pool to ``<args.output_dir>/arch_pool.<iteration>``;
    - pre-train, then train, the controller on the pool, where the training
      set adds predictor-labelled synthetic data;
    - decode the top ``args.k`` sequences into up to ``args.new_arch`` new
      valid architectures with 7 ops for the next iteration.

    After ``args.iteration`` iterations the best architectures are printed
    along with their test accuracy averaged over ``cs[108][0..2]``.

    Exits with status 1 when no CUDA device is available. All settings are
    read from the module-level ``args``.
    """
    if not torch.cuda.is_available():
        logging.info('No GPU found!')
        sys.exit(1)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = True

    logging.info("Args = %s", args)

    args.source_length = args.encoder_length = args.decoder_length = (
        args.nodes + 2) * (args.nodes - 1) // 2

    nasbench = api.NASBench(os.path.join(args.data, 'nasbench_full.tfrecord'))

    controller = NAO(
        args.encoder_layers,
        args.decoder_layers,
        args.mlp_layers,
        args.mlp_hidden_size,
        args.hidden_size,
        args.vocab_size,
        args.dropout,
        args.source_length,
        args.encoder_length,
        args.decoder_length,
    )
    logging.info("param size = %d", utils.count_parameters(controller))
    controller = controller.cuda()

    child_arch_pool, child_seq_pool, child_arch_pool_valid_acc = utils.generate_arch(
        args.seed_arch, nasbench, need_perf=True)

    arch_pool = []
    seq_pool = []
    arch_pool_valid_acc = []
    for i in range(args.iteration + 1):
        logging.info('Iteration {}'.format(i + 1))
        # Children decoded in the previous iteration arrive unscored.
        if not child_arch_pool_valid_acc:
            for arch in child_arch_pool:
                data = nasbench.query(arch)
                child_arch_pool_valid_acc.append(data['validation_accuracy'])

        arch_pool += child_arch_pool
        arch_pool_valid_acc += child_arch_pool_valid_acc
        seq_pool += child_seq_pool

        # Re-sort the whole pool by validation accuracy, best first.
        order = np.argsort(arch_pool_valid_acc)[::-1]
        arch_pool = [arch_pool[idx] for idx in order]
        seq_pool = [seq_pool[idx] for idx in order]
        arch_pool_valid_acc = [arch_pool_valid_acc[idx] for idx in order]
        with open(os.path.join(args.output_dir, 'arch_pool.{}'.format(i)),
                  'w') as fa:
            for arch, seq, valid_acc in zip(arch_pool, seq_pool,
                                            arch_pool_valid_acc):
                fa.write('{}\t{}\t{}\t{}\n'.format(arch.matrix, arch.ops, seq,
                                                   valid_acc))

        # Pools smaller than 10 would otherwise raise IndexError.
        top_n = min(10, len(arch_pool))
        print('Top 10 architectures:')
        for arch_index in range(top_n):
            print('Architecutre connection:{}'.format(
                arch_pool[arch_index].matrix))
            print('Architecture operations:{}'.format(
                arch_pool[arch_index].ops))
            print('Valid accuracy:{}'.format(arch_pool_valid_acc[arch_index]))

        if i == args.iteration:
            print('Final architectures:')
            for arch_index in range(top_n):
                print('Architecutre connection:{}'.format(
                    arch_pool[arch_index].matrix))
                print('Architecture operations:{}'.format(
                    arch_pool[arch_index].ops))
                print('Valid accuracy:{}'.format(
                    arch_pool_valid_acc[arch_index]))
                fs, cs = nasbench.get_metrics_from_spec(arch_pool[arch_index])
                test_acc = np.mean(
                    [cs[108][j]['final_test_accuracy'] for j in range(3)])
                print('Mean test accuracy:{}'.format(test_acc))
            break

        train_encoder_input = seq_pool
        # Min-max normalize accuracies to [0, 1]. Fall back to a unit range
        # when all values are equal, which would otherwise divide by zero.
        min_val = min(arch_pool_valid_acc)
        max_val = max(arch_pool_valid_acc)
        val_range = (max_val - min_val) or 1.0
        train_encoder_target = [(acc - min_val) / val_range
                                for acc in arch_pool_valid_acc]

        # Pre-train
        logging.info('Pre-train EPD')
        train_controller(controller, train_encoder_input, train_encoder_target,
                         args.pretrain_epochs)
        logging.info('Finish pre-training EPD')
        # Generate synthetic data
        logging.info('Generate synthetic data for EPD')
        synthetic_encoder_input, synthetic_encoder_target = generate_synthetic_controller_data(
            nasbench, controller, train_encoder_input, args.random_arch)
        if args.up_sample_ratio is None:
            # np.int was removed in NumPy 1.24; use the builtin int.
            up_sample_ratio = int(
                np.ceil(args.random_arch / len(train_encoder_input)))
        else:
            up_sample_ratio = args.up_sample_ratio
        # Repeat the real data so it is not swamped by the synthetic samples.
        all_encoder_input = train_encoder_input * up_sample_ratio + synthetic_encoder_input
        all_encoder_target = train_encoder_target * up_sample_ratio + synthetic_encoder_target
        # Train
        logging.info('Train EPD')
        train_controller(controller, all_encoder_input, all_encoder_target,
                         args.epochs)
        logging.info('Finish training EPD')

        new_archs = []
        new_seqs = []
        predict_step_size = 0
        # Seed the decoder with the k best sequences, real and synthetic.
        unique_input = train_encoder_input + synthetic_encoder_input
        unique_target = train_encoder_target + synthetic_encoder_target
        unique_indices = np.argsort(unique_target)[::-1]
        unique_input = [unique_input[idx] for idx in unique_indices]
        topk_archs = unique_input[:args.k]
        controller_infer_dataset = utils.ControllerDataset(
            topk_archs, None, False)
        controller_infer_queue = torch.utils.data.DataLoader(
            controller_infer_dataset,
            batch_size=len(controller_infer_dataset),
            shuffle=False,
            pin_memory=True)

        # Increase the step size until enough new valid archs are decoded
        # or args.max_step_size is exceeded.
        while len(new_archs) < args.new_arch:
            predict_step_size += 1
            logging.info('Generate new architectures with step size %d',
                         predict_step_size)
            new_seq, new_perfs = controller_infer(controller_infer_queue,
                                                  controller,
                                                  predict_step_size,
                                                  direction='+')
            for seq in new_seq:
                matrix, ops = utils.convert_seq_to_arch(seq)
                arch = api.ModelSpec(matrix=matrix, ops=ops)
                if nasbench.is_valid(arch) and len(
                        arch.ops
                ) == 7 and seq not in train_encoder_input and seq not in new_seqs:
                    new_archs.append(arch)
                    new_seqs.append(seq)
                if len(new_seqs) >= args.new_arch:
                    break
            logging.info('%d new archs generated now', len(new_archs))
            if predict_step_size > args.max_step_size:
                break

        child_arch_pool = new_archs
        child_seq_pool = new_seqs
        child_arch_pool_valid_acc = []  # scored at the top of the next iteration
        logging.info("Generate %d new archs", len(child_arch_pool))

    print(nasbench.get_budget_counters())