示例#1
0
        msg = '** Epoch : {0:>2} finished, Train Loss : {1:>6.2}, Train Acc : {2:6.2%}, Time : {3}'
        print_log(msg.format(epoch + 1, total_loss, total_acc, time_diff), file = log)

if __name__ == '__main__':
    # Script entry point: build the configuration, vocabulary and embedding
    # matrix, open a per-run log file, then train the model.

    # read config
    config = Config.ModelConfig()
    arg = config.arg

    # Load the vocabulary and record its size on the config object.
    vocab_dict = load_vocab(arg.vocab_path)
    arg.vocab_dict_size = len(vocab_dict)

    # Use pretrained vectors when a path is configured, otherwise build the
    # matrix with init_embeddings. The matrix shape (rows, width) then
    # overrides arg.n_vocab / arg.embedding_size so the model matches it.
    if arg.embedding_path:
        embeddings = load_embeddings(arg.embedding_path, vocab_dict)
    else:
        embeddings = init_embeddings(vocab_dict, arg.embedding_size)
    arg.n_vocab, arg.embedding_size = embeddings.shape
    if arg.embedding_normalize:
        embeddings = normalize_embeddings(embeddings)
    arg.n_classes = len(CATEGORIE_ID)

    # One log file per run, named by the start timestamp.
    dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    arg.log_path = 'config/log/log.{}'.format(dt)

    # Create the log directory on first use instead of failing in open().
    import os
    os.makedirs(os.path.dirname(arg.log_path), exist_ok=True)

    # The context manager guarantees the log is flushed and closed even if
    # training raises. `log` (and `model`) are still bound at module level,
    # so train(), which takes no arguments and presumably reads them as
    # globals, keeps working unchanged.
    with open(arg.log_path, 'w', encoding='utf-8') as log:
        print_log('CMD : python3 {0}'.format(' '.join(sys.argv)), file=log)
        print_log('Training with following options :', file=log)
        print_args(arg, log)

        model = Decomposable(arg, export=False)
        train()
示例#2
0
    # Prediction entry point: build the configuration, vocabulary and
    # embedding matrix, open a per-run log file, then run predict().

    # read config
    config = Config.ModelConfig()
    arg = config.arg

    # Load the vocabulary and record its size on the config object.
    vocab_dict = load_vocab(arg.vocab_path)
    arg.vocab_dict_size = len(vocab_dict)
    # Reverse lookup built by swapping vocab_dict's keys and values.
    # Not used in this block; kept because predict() takes no arguments and
    # presumably reads it as a global.
    index2word = {index: word for word, index in vocab_dict.items()}

    # Use pretrained vectors when a path is configured, otherwise build the
    # matrix with init_embeddings. The matrix shape (rows, width) then
    # overrides arg.n_vocab / arg.embedding_size so the model matches it.
    if arg.embedding_path:
        embeddings = load_embeddings(arg.embedding_path, vocab_dict)
    else:
        embeddings = init_embeddings(vocab_dict, arg.embedding_size)
    arg.n_vocab, arg.embedding_size = embeddings.shape

    if arg.embedding_normalize:
        embeddings = normalize_embeddings(embeddings)

    arg.n_classes = len(CATEGORIE_ID)

    # One log file per run, named by the start timestamp.
    dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    arg.log_path = 'config/log/log.{}'.format(dt)

    # Create the log directory on first use instead of failing in open().
    import os
    os.makedirs(os.path.dirname(arg.log_path), exist_ok=True)

    # The context manager guarantees the log is flushed and closed even if
    # prediction raises; `log` and `model` stay bound in the same scope as
    # before, so predict() can keep reading them.
    with open(arg.log_path, 'w', encoding='utf-8') as log:
        print_log('CMD : python3 {0}'.format(' '.join(sys.argv)), file=log)
        print_log('Testing with following options :', file=log)
        print_args(arg, log)

        model = Decomposable(
            arg.seq_length, arg.n_vocab, arg.embedding_size, arg.hidden_size,
            arg.attention_size, arg.n_classes, arg.batch_size,
            arg.learning_rate, arg.optimizer, arg.l2, arg.clip_value,
        )
        predict()