def test_srl_neural_model(train_file_path, valid_file_path):
    """Train an SRLNetwork on a CoNLL-05 corpus, evaluating on a validation set.

    Loads the training and validation corpora, builds an ``SRLNetwork`` from a
    fixed architecture, and trains it for ``hyper_param.n_epochs`` passes over
    every (X, y) batch produced for every training sentence.  Before every
    ``valid_freq``-th batch the current model is passed to ``evaluate``
    together with the validation problem and the index of that evaluation
    round.

    :param train_file_path: path to the CoNLL-05 training file.
    :param valid_file_path: path to the CoNLL-05 validation file.
    """
    train_corpora = Conll05Corpora()
    train_corpora.load(train_file_path)
    train_problem = SRLProblem(train_corpora)

    valid_corpora = Conll05Corpora()
    valid_corpora.load(valid_file_path)
    valid_problem = SRLProblem(valid_corpora)

    init_rng()

    nn_architecture = SRLNetowrkArchitecture()

    nn_architecture.word_feature_dim = 50
    nn_architecture.pos_feature_dim = 50
    nn_architecture.dist_feature_dim = 50

    nn_architecture.conv_window_height = 3
    nn_architecture.conv_output_dim = 500

    nn_architecture.hidden_layer_output_dims = [500, 500]

    # NOTE(review): apart from n_epochs (used below) these hyper-parameters are
    # never handed to the network or optimizer in this function -- confirm
    # whether CGDOptimizer/SRLNetwork are meant to receive them.
    hyper_param = NeuralModelHyperParameter()

    hyper_param.n_epochs = 1000
    hyper_param.learning_rate = 0.001  # NOTE(review): trailing "1" looked like an earlier value
    hyper_param.learning_rate_decay_ratio = 0.8
    hyper_param.learning_rate_lowerbound = 0.0000
    hyper_param.l1_reg = 0
    hyper_param.l2_reg = 0

    problem_character = train_problem.get_problem_property()

    m = SRLNetwork(problem_character, nn_architecture)

    optimizer = CGDOptimizer()

    trained_batch_num = 0
    valid_freq = 10

    # Drive the epoch count from the hyper-parameter object rather than a
    # second hard-coded literal so the two cannot drift apart.
    for _epoch in range(hyper_param.n_epochs):
        for sentence in train_problem.sentences():
            for X, y in train_problem.get_dataset_for_sentence(sentence):

                if trained_batch_num % valid_freq == 0:
                    # Floor division: the evaluation round index is an integer
                    # (same result as Py2 int "/", but explicit and Py3-safe).
                    evaluate(m, valid_problem, trained_batch_num // valid_freq)

                # Each sentence yields batches of varying size; the optimizer
                # is told the current batch size before each update.
                optimizer.batch_size = X.shape[0]
                optimizer.update_chunk(X, y)

                param = optimizer.optimize(m, m.get_parameter())

                m.set_parameter(param)

                trained_batch_num += 1
def test_srl_label_formatter(data_file_path):

    conll05corpora = Conll05Corpora()
    conll05corpora.load(data_file_path)
    print 'load done'

    test_label_file_path = "test_label.txt"
    pred_label_file_path = "pred_label.txt"

    srl_problem = SRLProblem(conll05corpora)

    for valid_sentence in srl_problem.sentences():
        test_labels = []
        pred_labels = []

        test_label_file = open(test_label_file_path, "w")
        pred_label_file = open(pred_label_file_path, "w")

        for srl_x, srl_y in srl_problem.get_dataset_for_sentence(valid_sentence):
            pred_y = [random.choice(SrlTypes.LABEL_SRLTYPE_MAP.keys()) for i in range(srl_y.size)]

            test_labels.append(srl_y)
            pred_labels.append(pred_y)

        test_label_str = srl_problem.pretty_srl_test_label(valid_sentence, test_labels)
        pred_label_str = srl_problem.pretty_srl_predict_label(valid_sentence, pred_labels)

        test_label_file.write(test_label_str)
        pred_label_file.write(pred_label_str)

        test_label_file.close()
        pred_label_file.close()

        try:

            valid_result = eval_srl(test_label_file_path, pred_label_file_path)
        except:
            print "label = "
            print "\n".join([" ".join([SrlTypes.LABEL_SRLTYPE_MAP[x] for x in  label]) for label in pred_labels])
            print "formatted_label = "
            print pred_label_str
def test_srl_problem(data_file_path):

    conll05corpora = Conll05Corpora()
    conll05corpora.load(data_file_path)
    print 'load done'

    srl_problem = SRLProblem(conll05corpora)

    for sentence in srl_problem.sentences():
        for X, y in srl_problem.get_dataset_for_sentence(sentence):
            word_id = [word.id for word in sentence.words()]
            print "word_id = ", word_id
            print "X = ", X[0]

            break


        for srl in sentence.srl_structs():
            print "verb = ", srl.verb.id

        break