import json
import sys

import tensorflow as tf

from utils import Data
from CharCNN1 import CharCNN1
from CharCNN2 import CharCNN2
from CharTCN import CharTCN

# Command-line flag `--m` choosing which character-level model to build;
# the __main__ section dispatches on FLAGS.m.
tf.flags.DEFINE_string("m", "CharCNN1", "Select between models CharCNN1, CharCNN2, and CharTCN")
FLAGS = tf.flags.FLAGS
# `_parse_flags()` is a private method that only exists on the old
# argparse-based tf.flags implementation. Newer TF 1.x releases back tf.flags
# with absl.flags, which has no such method; there, flags are parsed by calling
# the FlagValues object with argv. known_only=True leaves unrecognised
# arguments alone, matching the old parse-known-args behaviour.
if hasattr(FLAGS, "_parse_flags"):
    FLAGS._parse_flags()
else:
    FLAGS(sys.argv, known_only=True)

if __name__ == "__main__":
    config = json.load(open("config.json"))
    train_data = Data(path=config["data"]["train_path"],
                      input_size=config["data"]["input_size"],
                      vocab=config["data"]["vocab"],
                      num_classes=config["data"]["num_classes"])
    X_train, y_train = train_data.load()
    
    dev_data = Data(path=config["data"]["dev_path"],
                      input_size=config["data"]["input_size"],
                      vocab=config["data"]["vocab"],
                      num_classes=config["data"]["num_classes"])
    X_dev, y_dev = dev_data.load()
    
    if FLAGS.m == "CharCNN1":
        m = CharCNN1(input_size=config["data"]["input_size"],
                 vocab_size=config["data"]["vocab_size"],
                 embedding_size=config["data"]["embedding_size"],
                 num_classes=config["data"]["num_classes"],
                 conv_layers=config["cnn1"]["conv_layers"],
                 fc_layers=config["cnn1"]["fc_layers"],
                 optim_alg=config["cnn1"]["optim_alg"],