import pickle

import matplotlib.pyplot as plt
from nasbench import api
from scipy.special import softmax

# Project-local imports; the exact module paths are assumed from the repo layout.
from nasbench_analysis.search_spaces.search_space import SearchSpace
from nasbench_analysis.search_spaces.search_space_1 import SearchSpace1
from nasbench_analysis.search_spaces.search_space_2 import SearchSpace2
from nasbench_analysis.search_spaces.search_space_3 import SearchSpace3
from nasbench_analysis.utils import (CONV1X1, INPUT, OUTPUT, PRIMITIVES,
                                     NasbenchWrapper, get_top_k)


def analysis():
    # Load NASBench
    nasbench = NasbenchWrapper('nasbench_analysis/nasbench_data/108_e/nasbench_full.tfrecord')

    test_error = []
    valid_error = []

    search_space_1 = SearchSpace1()

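    # Smoke test: sample 10,000 random architectures and check that each one
    # converts to a valid NASBench ModelSpec and can be queried (the returned
    # results are discarded).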
    for _ in range(10000):
        adjacency_matrix, node_list = search_space_1.sample()
        adjacency_list = adjacency_matrix.astype(int).tolist()  # np.int was removed in NumPy 1.24
        node_list = [INPUT, *node_list, OUTPUT]
        model_spec = api.ModelSpec(matrix=adjacency_list, ops=node_list)
        nasbench.query(model_spec)

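    # Exhaustively enumerate all architectures without loose ends and record
    # their test and validation errors.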
    for adjacency_matrix, ops, model_spec in search_space_1.generate_search_space_without_loose_ends():
        # Query NASBench
        data = nasbench.query(model_spec)
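        # `data` holds one entry per evaluation repeat of this architecture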
        for item in data:
            test_error.append(1 - item['test_accuracy'])
            valid_error.append(1 - item['validation_accuracy'])

    # Each architecture contributes len(data) entries, one per evaluation repeat.
    print('Number of architectures', int(len(test_error) / len(data)))

    plt.figure()
    plt.title(
        'Distribution of test error in search space (no. architectures {})'.format(
            int(len(test_error) / len(data))))
    plt.hist(test_error, bins=800, density=True)
    ax = plt.gca()
    ax.set_xscale('log')
    ax.set_yscale('log')
    plt.xlabel('Test error')
    plt.grid(True, which="both", ls="-", alpha=0.5)
    plt.tight_layout()
    plt.xlim(right=0.3)  # the left limit must stay positive on a log-scaled axis
    plt.savefig('nasbench_analysis/search_spaces/export/search_space_1/test_error_distribution.pdf', dpi=600)
    plt.show()

    plt.figure()
    plt.title('Distribution of validation error in search space (no. architectures {})'.format(
        int(len(valid_error) / len(data))))
    plt.hist(valid_error, bins=800, density=True)
    ax = plt.gca()
    ax.set_xscale('log')
    ax.set_yscale('log')
    plt.xlabel('Validation error')
    plt.grid(True, which="both", ls="-", alpha=0.5)
    plt.tight_layout()
    plt.xlim(right=0.3)  # the left limit must stay positive on a log-scaled axis
    plt.savefig('nasbench_analysis/search_spaces/export/search_space_1/valid_error_distribution.pdf', dpi=600)
    plt.show()

    print('test_error', min(test_error), 'valid_error', min(valid_error))


def eval_one_shot_model(config, model):
    nasbench = NasbenchWrapper(
        dataset_file='/home/darts_weight_sharing_analysis/cnn/bohb/src/nasbench_analysis/nasbench_data/108_e/nasbench_only108.tfrecord')

    with open(model, 'rb') as f:
        model_list = pickle.load(f)
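    # model_list layout: [0] mixed-op alphas, [1] output-edge alphas,
    # [2:] per-node input alphas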

    alphas_mixed_op = model_list[0]
    chosen_node_ops = softmax(alphas_mixed_op, axis=-1).argmax(-1)

    node_list = [PRIMITIVES[i] for i in chosen_node_ops]
    alphas_output = model_list[1]
    alphas_inputs = model_list[2:]

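    # Decode the learned parent choices for each node; parents that are fixed
    # by the search space definition are hardcoded in the dicts below.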
    if config['search_space'] == '1':
        search_space = SearchSpace1()
        num_inputs = list(search_space.num_parents_per_node.values())[3:-1]
        parents_node_3, parents_node_4 = \
            [get_top_k(softmax(alpha, axis=1), num_input) for num_input, alpha in zip(num_inputs, alphas_inputs)]
        output_parents = get_top_k(softmax(alphas_output), num_inputs[-1])
        parents = {
            '0': [],
            '1': [0],
            '2': [0, 1],
            '3': parents_node_3,
            '4': parents_node_4,
            '5': output_parents
        }
        node_list = [INPUT, *node_list, CONV1X1, OUTPUT]

    elif config['search_space'] == '2':
        search_space = SearchSpace2()
        num_inputs = list(search_space.num_parents_per_node.values())[2:]
        parents_node_2, parents_node_3, parents_node_4 = \
            [get_top_k(softmax(alpha, axis=1), num_input) for num_input, alpha in zip(num_inputs[:-1], alphas_inputs)]
        output_parents = get_top_k(softmax(alphas_output), num_inputs[-1])
        parents = {
            '0': [],
            '1': [0],
            '2': parents_node_2,
            '3': parents_node_3,
            '4': parents_node_4,
            '5': output_parents
        }
        node_list = [INPUT, *node_list, CONV1X1, OUTPUT]

    elif config['search_space'] == '3':
        search_space = SearchSpace3()
        num_inputs = list(search_space.num_parents_per_node.values())[2:]
        parents_node_2, parents_node_3, parents_node_4, parents_node_5 = \
            [get_top_k(softmax(alpha, axis=1), num_input) for num_input, alpha in zip(num_inputs[:-1], alphas_inputs)]
        output_parents = get_top_k(softmax(alphas_output), num_inputs[-1])
        parents = {
            '0': [],
            '1': [0],
            '2': parents_node_2,
            '3': parents_node_3,
            '4': parents_node_4,
            '5': parents_node_5,
            '6': output_parents
        }
        node_list = [INPUT, *node_list, OUTPUT]

    else:
        raise ValueError('Unknown search space')

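    # Assemble the cell's adjacency matrix from the chosen parent sets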
    adjacency_matrix = search_space.create_nasbench_adjacency_matrix(parents)
    # Convert the adjacency matrix into the list format NASBench expects
    adjacency_list = adjacency_matrix.astype(int).tolist()  # np.int was removed in NumPy 1.24
    model_spec = api.ModelSpec(matrix=adjacency_list, ops=node_list)
    # Query nasbench
    data = nasbench.query(model_spec)
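    # Collect per-repeat metrics; NASBench evaluates each architecture several times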
    valid_error, test_error, runtime, params = [], [], [], []
    for item in data:
        test_error.append(1 - item['test_accuracy'])
        valid_error.append(1 - item['validation_accuracy'])
        runtime.append(item['training_time'])
        params.append(item['trainable_parameters'])
    return test_error, valid_error, runtime, params
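
# Minimal usage sketch for eval_one_shot_model (hypothetical: the config keys
# and the pickle path below are illustrative, not fixed by this module):
#
#     config = {'search_space': '1'}
#     test_err, valid_err, runtime, params = eval_one_shot_model(
#         config, 'experiments/darts/one_shot_architecture_0.obj')
#     print('best valid:', min(valid_err), 'best test:', min(test_err))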


def analysis_generic_search_space():
    # Variant of analysis() above that builds search space 1 from the generic
    # SearchSpace class; renamed so it does not shadow the first definition.
    search_space_1 = SearchSpace(
        num_parents_per_node={
            '0': 0,
            '1': 1,
            '2': 2,
            '3': 2,
            '4': 2,
            '5': 2
        },
        search_space_number=1,
        num_intermediate_nodes=4)

    # Load NASBench
    nasbench = NasbenchWrapper(
        '/home/siemsj/projects/darts_weight_sharing_analysis/nasbench_full.tfrecord'
    )

    test_error = []
    valid_error = []

    search_space_creator = search_space_1.create_search_space(
        with_loose_ends=False, upscale=False)
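    # Exhaustively enumerate the search space and record per-repeat errors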
    for adjacency_matrix, ops, model_spec in search_space_creator:
        # Query NASBench
        data = nasbench.query(model_spec)
        for item in data:
            test_error.append(1 - item['test_accuracy'])
            valid_error.append(1 - item['validation_accuracy'])

    print('Number of architectures', int(len(test_error) / len(data)))

    plt.figure()
    plt.title('Distribution of test error in search space (no. architectures {})'.format(
        int(len(test_error) / len(data))))
    plt.hist(test_error, bins=800, density=True)
    ax = plt.gca()
    ax.set_xscale('log')
    ax.set_yscale('log')
    plt.xlabel('Test error')
    plt.grid(True, which="both", ls="-", alpha=0.5)
    plt.tight_layout()
    plt.xlim(right=0.3)  # the left limit must stay positive on a log-scaled axis
    plt.savefig('nasbench_analysis/search_spaces/export/search_space_1/test_error_distribution.pdf', dpi=600)
    plt.show()

    plt.figure()
    plt.title('Distribution of validation error in search space (no. architectures {})'.format(
        int(len(valid_error) / len(data))))
    plt.hist(valid_error, bins=800, density=True)
    ax = plt.gca()
    ax.set_xscale('log')
    ax.set_yscale('log')
    plt.xlabel('Validation error')
    plt.grid(True, which="both", ls="-", alpha=0.5)
    plt.tight_layout()
    plt.xlim(right=0.3)  # the left limit must stay positive on a log-scaled axis
    plt.savefig('nasbench_analysis/search_spaces/export/search_space_1/valid_error_distribution.pdf', dpi=600)
    plt.show()

    print('test_error', min(test_error), 'valid_error', min(valid_error))