def cross_validation_test():
    """Cross-validate a committee of feed-forward networks on the survival data.

    The data set is read from a hard-coded path, tail-censored rows are
    removed, and a committee of ``comsize`` networks is trained with
    ``train_evolutionary`` using ``c_index_error`` as the error function.
    The settings, the per-network test/validation errors and their
    averages are printed; nothing is returned.

    Command line:
        sys.argv[1] -- optional number of hidden nodes (default 3).

    Raises:
        ValueError: if ``sys.argv[1]`` is given but is not an integer.
    """
    glogger.setLoggingLevel(glogger.nothing)

    filename = "/home/gibson/jonask/Dropbox/Ann-Survival-Phd/Two_thirds_of_SA_1889_dataset.txt"

    # Interactive column selection is disabled; every available column is used.
    columns = (2, -4, -3, -2, -1)
    print('\nIncluding columns: ' + str(columns))

    P, T = parse_file(filename, targetcols = [4, 5], inputcols = columns, ignorerows = [0], normalize = True)
    # Remove tail censored
    P, T = copy_without_tailcensored(P, T)

    # Number of networks in the committee (interactive prompt disabled).
    comsize = 10
    print('Number of networks to cross-validate: ' + str(comsize))

    # The hidden layer size may be overridden by the first command-line argument.
    if len(sys.argv) < 2:
        netsize = 3
    else:
        # sys.argv entries are strings; convert to an integer node count.
        netsize = int(sys.argv[1])
    print("Number of hidden nodes: " + str(netsize))

    # Settings for the evolutionary training (interactive prompts disabled).
    pop_size = 50
    print("Population size: " + str(pop_size))

    mutation_rate = 0.25
    print("Mutation rate: " + str(mutation_rate))

    epochs = 200
    print("Epochs: " + str(epochs))

    com = build_feedforward_committee(comsize, len(P[0]), netsize, 1, output_function = 'linear')

    # 1 is the column in the target array which holds the binary censoring information
    test_errors, vald_errors = train_committee(com, train_evolutionary, P, T, 1, epochs, error_function = c_index_error, population_size = pop_size, mutation_chance = mutation_rate)

    # Materialise the values once; they are used for both the listing and the averages.
    test_values = list(test_errors.values())
    vald_values = list(vald_errors.values())

    print('\nTest Errors, Validation Errors:')
    for terr, verr in zip(test_values, vald_values):
        print(str(terr) + ", " + str(verr))

    print('\nTest average, Validation average:')
    test_average = sum(test_values) / len(test_values)
    vald_average = sum(vald_values) / len(vald_values)
    print(str(test_average) + ', ' + str(vald_average))
# Example #2
# 0
def cpu_multi(pop_size, P, num_of_hidden):
    """Build an 8-network tanh committee and run its ``sim`` on ``P`` via ``benchmark``.

    ``pop_size`` is accepted but not used by this function. The number of
    network inputs is taken from the length of the first row of ``P``.
    Nothing is returned.
    """
    # One network input per covariate in the first row of P.
    input_count = len(P[0])

    # tanh output: keeps implementation details easy in OpenCL.
    committee = build_feedforward_committee(
        8,
        input_count,
        num_of_hidden,
        1,
        output_function="tanh",
    )

    # Wrap the committee's sim with benchmark, then invoke it on the inputs.
    wrapped_sim = benchmark(committee.sim)
    wrapped_sim(P)
# Script: train a committee of networks with the evolutionary trainer and
# report classification statistics at the best ROC cut.
# NOTE(review): `filename`, `inputs` and `targets` are not defined in this part
# of the file -- presumably they are set up earlier; confirm.
test_inputs, tst_t = parse_file(filename)


test = (inputs, targets)
validation = ([], [])

# Train!

# net = build_feedforward(8, 8, 1)

epochs = 10

# best = benchmark(train_evolutionary)(net, test, validation, 10, random_range = 1)
# best = benchmark(traingd_block)(net, test, validation, epochs, block_size = 10, stop_error_value = 0)

com = build_feedforward_committee(size=10, input_number=8, hidden_number=8, output_number=1)

# Call form of print: same output under Python 2, and valid syntax under Python 3.
print("Training evolutionary...")
benchmark(train_committee)(com, train_evolutionary, inputs, targets, epochs, random_range=1)

# Simulate the trained committee and evaluate it at the best cut reported
# by get_rocarea_and_best_cut.
Y = com.sim(inputs)
area, best_cut = get_rocarea_and_best_cut(Y, targets)
num_correct_first, num_correct_second, total_performance, num_first, num_second, missed = stat(
    Y, targets, cut=best_cut
)
print(
    "Total number of data: " + str(len(targets)) + " (" + str(num_second) + " ones and " + str(num_first) + " zeros)"
)
print("Number of misses: " + str(missed) + " (" + str(total_performance) + "% performance)")
print("Specificity: " + str(num_correct_first) + "% (Success for class 0)")
print("Sensitivity: " + str(num_correct_second) + "% (Success for class 1)")
def cross_validation_test():
    """Repeatedly cross-validate a feed-forward committee on the survival data.

    The data set is read from a hard-coded path. A committee of ``comsize``
    networks is built and trained with ``train_evolutionary`` (error function
    ``c_index_error``) ``times_to_cross`` times, and the per-network test and
    validation errors are printed after each run.

    Command line:
        sys.argv[1]  -- optional number of hidden nodes (default 1).
        sys.argv[2:] -- optional input column indices
                        (default (2, -4, -3, -2, -1)).
    """

    filename = "/home/gibson/jonask/Dropbox/Ann-Survival-Phd/Two_thirds_of_SA_1889_dataset.txt"

    #try:
    #    columns = input("Which columns to include? (Do NOT forget trailing comma if only one column is used, e.g. '3,'\nAvailable columns are: 2, -4, -3, -2, -1. Just press ENTER for all columns.\n")
    #except SyntaxError:
    # Columns come from the command line when given, otherwise all are used.
    if len(sys.argv) < 3:
        columns = (2, -4, -3, -2, -1)
    else:
        columns = [int(col) for col in sys.argv[2:]]

    print('\nIncluding columns: ' + str(columns))

    P, T = parse_file(filename, targetcols = [4, 5], inputcols = columns, ignorerows = [0], normalize = True)
    #remove tail censored
    #print('\nRemoving tail censored...')
    #P, T = copy_without_censored(P, T)

    # Column 1 of T holds the binary censoring information (1 = event), so the
    # sums below count patients with events and censored patients.
    print("\nData set:")
    print("Number of patients with events: " + str(T[:, 1].sum()))
    print("Number of censored patients: " + str((1 - T[:, 1]).sum()))

    #try:
    #    comsize = input("Number of networks to cross-validate [10]: ")
    #except SyntaxError:
    comsize = 5
    print('\nNumber of networks to cross-validate: ' + str(comsize))

    times_to_cross = 3
    print('\nNumber of times to repeat cross-validation: ' + str(times_to_cross))

    #try:
    #    netsize = input('Number of hidden nodes [3]: ')
    #except SyntaxError as e:
    # Hidden layer size: first command-line argument, defaulting to 1.
    if len(sys.argv) < 2:
        netsize = 1
    else:
        netsize = int(sys.argv[1])
    print("Number of hidden nodes: " + str(netsize))

    #try:
    #    pop_size = input('Population size [50]: ')
    #except SyntaxError as e:
    pop_size = 100
    print("Population size: " + str(pop_size))

    #try:
    #    mutation_rate = input('Please input a mutation rate (0.25): ')
    #except SyntaxError as e:
    mutation_rate = 0.05
    print("Mutation rate: " + str(mutation_rate))

    #try:
    #    epochs = input("Number of generations (200): ")
    #except SyntaxError as e:
    epochs = 400
    print("Epochs: " + str(epochs))

    # A fresh committee is built for every repetition.
    # NOTE: xrange is the Python 2 builtin; this code targets Python 2.
    for _ in xrange(times_to_cross):
        com = build_feedforward_committee(comsize, len(P[0]), netsize, 1, output_function = 'linear')

        #1 is the column in the target array which holds the binary censoring information
        # data_sets is returned by train_committee but not used in the code below.
        test_errors, vald_errors, data_sets = train_committee(com, train_evolutionary, P, T, 1, epochs, error_function = c_index_error, population_size = pop_size, mutation_chance = mutation_rate)

        print('\nTest Errors, Validation Errors:')
        for terr, verr in zip(test_errors.values(), vald_errors.values()):
            print(str(terr) + ", " + str(verr))