示例#1
0
def main(nn_type, data_type):

    print(" Starting... ")
    filename_train = 'qa1_single-supporting-fact_train.txt'
    filename_test = 'qa1_single-supporting-fact_test.txt'
    directory = 'data/babi_tasks/tasks_1-20_v1-2/en/'
    num_epochs = 500

    #processor = Preprocessor(directory, filename_train, filename_test, data_type)
    #X_train, y_train, mask_train, X_test, y_test, mask_test, input_size, max_seq_len, idx2word = processor.extract_data()

    #wProc = WikiProcessor('C:/Users/Dan/Desktop/Crore/6.864/Project/Data/wiki_qa/')
    #wProc.process()

    #proc = CNNProcessor()
    # proc.process()


    if nn_type == "lstm":
        proc = BabiProcessor(data_type)
        X_train, y_train, mask_train, X_test, y_test, mask_test, input_size, max_seq_len, idx2word = proc.process()
        lstm = LSTM(X_train, y_train, mask_train, X_test, y_test, mask_test, idx2word)
        network, l_mask, l_in = lstm.build_model(input_size, max_seq_len)
        lstm.optimize(network, l_mask, l_in)

    elif nn_type == "mem_net" and data_type == "babi":
        mn = MemNet()
        mn.run('babi')
    elif nn_type == "mem_net" and data_type == "wiki_qa":
        mn = MemNet()
        mn.run(data_type)
    elif nn_type == "mem_net" and data_type == "cnn":
        mn = MemNet()
        mn.run('cnn_qa')

    elif nn_type == "dynam_net":
        proc = BabiProcessor(data_type, "dynam_net")
        X_train, Q_train, Y_train, mask_train, X_test, Q_test, Y_test, mask_test, input_size, max_seqlen, idx2word, max_queslen = proc.process()
        dn = DynamicMemNet(X_train, Q_train, Y_train, mask_train, X_test, Q_test, Y_test, mask_test, input_size, max_seqlen, idx2word, max_queslen)
        dn.build()
        dn.train()
    elif nn_type == "dynam_net_theano":
        #num_fact_hidden_units, number_classes, number_fact_embeddings, dimension_fact_embeddings, num_episode_hidden_units

        dmn_t = DMN_full_babi()
        dmn_t.train()
        print("Finished DMN Theano")