# Example No. 1
# 0
if PERSIST:
    # Pickle `word_indices` to the path returned by best_model_word_indices().
    # The `with` block guarantees the file is flushed and closed even if
    # pickling raises; the previous form left closing the handle to the
    # garbage collector.
    with open(best_model_word_indices(), 'wb') as word_indices_file:
        pickle.dump(word_indices, word_indices_file)

# Pick the data split. Non-FINAL runs get a separate validation set.
# FINAL runs use the two-way split from load_final() and bind no
# `validation` here.
if not FINAL:
    training, validation, testing = loader.load_train_val_test()
else:
    print("\n > running in FINAL mode!\n")
    training, testing = loader.load_final()

# Post-mortem evaluation against the SemEval gold data. The previous test
# split is demoted to validation, and the prepared gold set becomes the new
# test set. FINAL is then cleared (presumably so later code runs the
# validation-based path; confirm against the rest of the script).
if POST_MORTEM:
    print("\n > running in Post-Mortem mode!\n")
    gold_data = SemEvalDataLoader().get_gold(task=TASK)
    # Element 1 of each gold observation feeds X and element 0 feeds y.
    # X is built first, matching the original evaluation order.
    gX = [observation[1] for observation in gold_data]
    gy = [observation[0] for observation in gold_data]
    gold = prepare_dataset(gX, gy, loader.pipeline, loader.y_one_hot)

    validation, testing = testing, gold
    FINAL = False

############################################################################
# NN MODEL
############################################################################
# NOTE(review): this statement is truncated in this source. The remaining
# target_RNN(...) arguments and the closing parenthesis are missing; the
# lines that follow restart the data-loading logic instead. As shown it
# does not parse. Restore the full call before running.
# `embeddings`, `text_max_length` and `target_max_length` are defined
# elsewhere in the script (not visible here).
print("Building NN Model...")
nn_model = target_RNN(embeddings,
                      tweet_max_length=text_max_length,
                      aspect_max_length=target_max_length,
                      noise=0.2,
                      activity_l2=0.001,
                      drop_text_rnn_U=0.2,
# Pick the data split. Non-FINAL runs get a separate validation set.
# FINAL runs use the two-way split from load_final() and bind no
# `validation` here.
if not FINAL:
    training, validation, testing = loader.load_train_val_test()
else:
    print("\n > running in FINAL mode!\n")
    training, testing = loader.load_final()

# Evaluate against the SemEval-2017 Task 6 gold data (subtask 1 loader).
# The previous test split becomes validation, the prepared gold set becomes
# the test set, and FINAL is cleared.
if SEMEVAL_GOLD:
    print("\n > running in Post-Mortem mode!\n")
    gold_data = SemEval2017Task6().get_gold_data_task_1()
    # Order the entries by key, keeping only the values.
    gold_data = [entry for _key, entry in sorted(gold_data.items())]
    # Flatten across entries: element 0 of each entry holds X items and
    # element 1 holds y items.
    X = [item for entry in gold_data for item in entry[0]]
    y = [item for entry in gold_data for item in entry[1]]
    # For subtask "2" the labels are passed through unchanged (y_as_is).
    gold = prepare_dataset(X, y, loader.pipeline, loader.y_one_hot,
                           y_as_is=(loader.subtask == "2"))

    validation, testing = testing, gold
    FINAL = False

print("Building NN Model...")
# humor_CNN and humor_FFNN were the alternative architectures tried here.
# They were called with the same (embeddings, text_length) arguments and
# can be swapped in for comparison.
nn_model = humor_RNN(embeddings, text_length)

# Write a diagram of the model, with layer names and shapes shown, to a PNG
# whose filename is tagged with the current TASK.
plot(nn_model,
     to_file="model_task6_sub{}.png".format(TASK),
     show_shapes=True,
     show_layer_names=True)