示例#1
0
def gen_gp_test_data():
    """ Build train/test splits of synthetic NN-architecture data for unit tests.

    Three datasets are produced, in this order:
      * CNNs                  -- first 7 shuffled architectures for training,
                                 labelled with cnn_syn_func1;
      * regression MLPs       -- first 5 for training, labelled with
                                 mlp_syn_func1;
      * classification MLPs   -- first 5 for training, labelled with
                                 syn_func1_common.
    For each dataset the generated architectures are shuffled in place with
    numpy's global RNG (np.random.shuffle) before being split, so results
    depend on the current numpy random state.

    Returns:
      A list of (X_tr, Y_tr, X_te, Y_te, nn_type) tuples. X_tr / X_te are
      slices of the generated architecture collection; Y_tr / Y_te are numpy
      arrays of the corresponding synthetic objective values; nn_type is one
      of 'cnn', 'mlp-reg', 'mlp-class'.
    """
    # Each spec: (architecture factory, number of training points,
    #             labelling function, nn type tag).
    # Factories are invoked lazily inside the loop so that generation,
    # shuffling and labelling happen dataset by dataset, in order.
    specs = (
        (generate_cnn_architectures, 7, cnn_syn_func1, 'cnn'),
        (lambda: generate_mlp_architectures('reg'), 5, mlp_syn_func1, 'mlp-reg'),
        (lambda: generate_mlp_architectures('class'), 5, syn_func1_common,
         'mlp-class'),
    )
    datasets = []
    for make_archs, num_train, label_func, nn_type in specs:
        archs = make_archs()
        np.random.shuffle(archs)
        train_archs = archs[:num_train]
        train_vals = np.array([label_func(arch) for arch in train_archs])
        test_archs = archs[num_train:]
        test_vals = np.array([label_func(arch) for arch in test_archs])
        datasets.append((train_archs, train_vals, test_archs, test_vals, nn_type))
    return datasets
示例#2
0
def get_nn_opt_arguments():
  """ Assemble the fixtures used by the NN optimisation tests.

  Returns:
    A Namespace holding, for CNNs and for regression MLPs:
      - a constraint checker (cnn_/mlp_constraint_checker),
      - a mutation operator built from that checker (cnn_/mlp_mutation_op),
      - an initial pool of networks (cnn_/mlp_init_pool) and the synthetic
        objective value of each pool member (cnn_/mlp_init_vals),
      - an NNDomain tagged with the network type (cnn_/mlp_domain);
    plus nn_domain, an NNDomain constructed with (None, None).
  """
  # Both checkers receive identical positional arguments; what each argument
  # means is defined by the checker classes, which are not visible here.
  checker_args = (50, 5, 1e8, 0, 5, 5, 200, 1024, 8)

  def _make_mutation_op(checker):
    """ Build a mutation operator for `checker` (fresh probability list each call). """
    return get_nn_modifier_from_args(checker, [0.5, 0.25, 0.125, 0.075, 0.05])

  args = Namespace()
  # Constraint checkers and the mutation operators derived from them.
  args.cnn_constraint_checker = CNNConstraintChecker(*checker_args)
  args.mlp_constraint_checker = MLPConstraintChecker(*checker_args)
  args.cnn_mutation_op = _make_mutation_op(args.cnn_constraint_checker)
  args.mlp_mutation_op = _make_mutation_op(args.mlp_constraint_checker)
  # Initial pools, each paired with the synthetic objective of its members.
  args.cnn_init_pool = get_initial_cnn_pool()
  args.cnn_init_vals = [cnn_syn_func1(net) for net in args.cnn_init_pool]
  args.mlp_init_pool = get_initial_mlp_pool('reg')
  args.mlp_init_vals = [mlp_syn_func1(net) for net in args.mlp_init_pool]
  # Domains: one built with (None, None), then one per network type.
  args.nn_domain = NNDomain(None, None)
  args.cnn_domain = NNDomain('cnn', args.cnn_constraint_checker)
  args.mlp_domain = NNDomain('mlp-reg', args.mlp_constraint_checker)
  return args
示例#3
0
def syn_cnn_1(x):
  """ Evaluate the synthetic CNN objective cnn_syn_func1 on x[0].

  (The previous docstring claimed this computes the Branin function; it does
  not -- it simply forwards the first element of `x` to cnn_syn_func1.)

  Args:
    x: an indexable container; presumably a one-element sequence wrapping a
       CNN architecture, as used by a point-based optimiser -- TODO confirm
       against callers.
  Returns:
    Whatever cnn_syn_func1 returns for x[0].
  """
  return cnn_syn_func1(x[0])