示例#1
0
def main():
    """Train a network for butadiene, optionally resuming from a saved model.

    First asks the user to confirm that the model file at ``save_path`` may
    be overwritten and aborts unless the answer is "y". If a model can be
    loaded from that file, the user may choose to continue training it
    (the loaded weights are then actually used); otherwise a new
    ``EluTrNNN`` with a freshly built layer structure is created. The
    network, the save path and the best test error so far are handed to
    ``train_network``.
    """
    # TODO: this function and the training should become part of the
    # library, so that one only has to specify the save path and dataset.

    msg.info("Traing a network for butadien", 2)

    msg.info("Fetching dataset ... ", 2)
    dataset = prep_dataset()

    save_path = "butadien/data/networks/networkS400.npy"

    user_input = msg.input(
        "This will overwrite the model at " + save_path +
        ". Are you sure you want that? (y for yes)"
    )

    if user_input.upper() != "Y":
        msg.info("Aborting", 2)
        return

    # Stays None unless the user decides to continue with a loaded model.
    network = None

    msg.info("Try to fetch current model")
    try:
        # The file is indexed as [structure, weights, biases, test_error];
        # presumably it is an object array, which numpy >= 1.16.3 only
        # unpickles with allow_pickle=True. Unpickling can execute code, so
        # save_path must only point at files this script wrote itself.
        model = np.load(save_path, encoding="latin1", allow_pickle=True)
        structure, weights, biases = model[0], model[1], model[2]
        loaded_network = EluFixedValue(structure, weights, biases)
        test_error = model[3]
    except Exception as ex:
        # Best effort: a missing or unreadable model just means we start
        # from scratch, but report why instead of failing silently.
        # (Exception, not a bare except, so Ctrl-C / SystemExit propagate.)
        msg.info("No usable model found (" + str(ex) + ")", 2)
    else:
        # The prompt lives outside the try so an interrupt here is not
        # mistaken for a failed model load.
        user_input = msg.input(
            "Model found with test error :  " + str(test_error) +
            ". Do you want to continue to train it? (y for yes)"
        )

        if user_input.upper() == "Y":
            # Continue from the loaded weights instead of discarding them.
            network = loaded_network
        else:
            msg.info("Creating new network", 2)

    if network is None:
        # Number of entries in the upper triangle (incl. diagonal) of a
        # DIM x DIM matrix.
        dim_triu = int(DIM * (DIM + 1) / 2)
        structure = [
            dim_triu,
            int(dim_triu * 0.75),
            int(dim_triu * 0.5), dim_triu, dim_triu
        ]
        # Sentinel "worse than anything" error for a brand-new network.
        test_error = 1e10
        network = EluTrNNN(structure)

    msg.info("Train ... ", 2)

    train_network(dataset, network, save_path, test_error)

    msg.info("All done. Bye bye..", 2)
示例#2
0
def main():
    """Train a butadiene network and let the user decide whether to keep it.

    After training, the user is asked for confirmation; on "y" the network
    is exported (together with its session) to ``butadien/data/network.npy``,
    otherwise it is discarded.
    """
    msg.info("Traing a network for butadien", 2)

    msg.info("Fetching dataset ... ", 2)
    data = prep_dataset()

    msg.info("Train ... ", 2)
    # The trainer object returned first is not needed here.
    _, trained_net, session = train_network(data)

    answer = msg.input("Keep this network for butadien " + " (y for yes)?")
    keep_it = answer.upper() == "Y"

    if not keep_it:
        msg.info("Network discarded ...", 2)
    else:
        target = join("butadien/data", "network.npy")
        trained_net.export(session, target)
        msg.info("Exported network to: " + target, 2)

    msg.info("All done. Bye bye..", 2)
示例#3
0
def main(molecule_type):
    """Train a network for ``molecule_type`` and optionally export it.

    The dataset and the training are both selected by ``molecule_type``.
    If the user confirms with "y", the trained network is exported to
    ``cc2ai/<molecule_type>/network_<molecule_type>.npy``; otherwise it is
    discarded.
    """
    msg.info("Traing a network for " + molecule_type, 2)

    msg.info("Fetching dataset ... ", 2)
    data = prep_dataset(molecule_type)

    msg.info("Train ... ", 2)
    # Only the network and its session are used; the trainer is ignored.
    _, trained_net, session = train_network(molecule_type, data)

    question = "Keep this network for " + molecule_type + " (y for yes)?"
    answer = msg.input(question)

    if answer.upper() != "Y":
        msg.info("Network discarded ...", 2)
    else:
        file_name = "network_" + molecule_type + ".npy"
        target = join("cc2ai", molecule_type, file_name)
        trained_net.export(session, target)
        msg.info("Exported network to: " + target, 2)

    msg.info("All done. Bye bye..", 2)