示例#1
0
文件: ann.py 项目: afcarl/SciProjects
def run():
    """Train one network on the cross-validated data set and report scores.

    The learning/testing split comes from the ``CData`` frame built from
    ``gyumpath``; a second frame built from ``zsindpath`` (no cross-validation
    split) is pushed through the first frame's transformation and embedding
    and used as an extra evaluation set.

    Prints the cost/accuracy pair measured before training and the two final
    scores, then returns ``(tacc, vacc)``: the last value reported by
    ``network.evaluate`` on the testing table and on the second data set.
    """
    data = CData(gyumpath, gyumindeps, feature=TAXLEVEL, cross_val=CROSSVAL)
    data.transformation = (TRANSFORMATION, TRANSFORMATION_PARAM)

    net = build_net(*data.neurons_required)

    test_set = data.table("testing")

    # Second data set: reuse the transformation and label embedding of the
    # primary frame so both sets live in the same feature/label space.
    holdout = CData(zsindpath, zsindeps, cross_val=0.0, feature=TAXLEVEL)
    val_x = holdout.learning
    val_y = holdout.lindeps
    val_x = data.transformation(val_x)
    val_y = data.embed(val_y)

    # Baseline: score of the untrained network on the testing table.
    initial_cost, initial_acc = net.evaluate(*test_set, verbose=0)
    print("Initial cost: {}\tacc: {}".format(round(initial_cost, 5),
                                             round(initial_acc, 5)))

    train_x, train_y = data.table("learning")
    fit_options = {
        "batch_size": 20,
        "nb_epoch": 400,
        "validation_data": test_set,
        "verbose": 0,
    }
    net.fit(train_x, train_y, **fit_options)

    # The testing table is evaluated in a single batch of n_testing samples.
    test_scores = net.evaluate(*test_set, batch_size=data.n_testing, verbose=0)
    tacc = test_scores[-1]
    holdout_scores = net.evaluate(val_x, val_y, verbose=0)
    vacc = holdout_scores[-1]

    print("T: {}\tV: {}".format(tacc, vacc))
    return tacc, vacc
示例#2
0
文件: ann.py 项目: afcarl/SciProjects
def full_training(validate=True, dump_weights=False):
    """Fit a network on the complete learning table, optionally validating it.

    Parameters:
        validate: when true, evaluate the trained network on the data loaded
            from ``zsindpath`` and return its predictions.
        dump_weights: when true, write the transposed weight matrix of the
            first layer to ``weights.csv`` in the current working directory
            (tab-separated cells, one row per line, decimal points replaced
            by commas).

    Returns:
        ``(probs, preds, vy, fruits)`` when ``validate`` is true, where
        ``probs``/``preds`` are the outputs of ``predict_proba`` and
        ``predict_classes`` on the validation inputs, ``vy`` the embedded
        validation labels and ``fruits`` the training data frame.
        ``None`` when ``validate`` is false.
    """
    fruits = CData(gyumpath, gyumindeps, feature=TAXLEVEL, cross_val=0.0)
    fruits.transformation = (TRANSFORMATION, TRANSFORMATION_PARAM)

    network = build_net(*fruits.neurons_required)

    X, y = fruits.table("learning")
    network.fit(X, y, batch_size=30, nb_epoch=500, verbose=0)

    if dump_weights:
        weights = network.layers[0].get_weights()
        # Build the whole text first so a formatting error cannot leave a
        # truncated file behind, then write it under a context manager so the
        # handle is closed even if the write fails.
        rows = (
            "\t".join(str(float(cell)) for cell in line)
            for line in weights[0].T
        )
        text = "\n".join(rows).replace(".", ",")
        with open("weights.csv", "w") as wghts:
            wghts.write(text)

    if not validate:
        return None

    # Read inputs and labels from the frame's attributes, exactly as run()
    # does; the frame object itself is not unpacked into (inputs, labels).
    zsind = CData(zsindpath, zsindeps, cross_val=0.0, feature=TAXLEVEL)
    vx = fruits.transformation(zsind.learning)
    vy = fruits.embed(zsind.lindeps)

    # Evaluate the validation set in one batch.
    vacc = network.evaluate(vx, vy, batch_size=len(vy), verbose=0)[-1]
    probs = network.predict_proba(vx, verbose=0)
    preds = network.predict_classes(vx, verbose=0)
    print("ANN validation accuracy:", vacc)
    return probs, preds, vy, fruits