Code Example #1
import tensorflow as tf  # needed below for tf.train.Checkpoint / CheckpointManager
# Vocab, batcher, SequenceToSequence and train_model are project-local imports (not shown).

def train(params):
    global checkpoint_dir, ckpt, model
    assert params["mode"].lower() == "train", "change training mode to 'train'"

    vocab = Vocab(params["vocab_path"], params["vocab_size"])
    print('true vocab is ', vocab)

    print("Creating the batcher ...")
    b = batcher(vocab, params)  # PGN is covered in detail later; this just wraps the data into TF's expected batch format
    print("Building the model ...")
    if params["model"] == "SequenceToSequence":
        model = SequenceToSequence(params)
    # elif params["model"] == "PGN":
    #     model = PGN(params)

    print("Creating the checkpoint manager")
    if params["model"] == "SequenceToSequence":
        checkpoint_dir = "{}/checkpoint".format(params["seq2seq_model_dir"])
        ckpt = tf.train.Checkpoint(step=tf.Variable(0),
                                   SequenceToSequence=model)
    # elif params["model"] == "PGN":
    #     checkpoint_dir = "{}/checkpoint".format(params["pgn_model_dir"])
    #     ckpt = tf.train.Checkpoint(step=tf.Variable(0), PGN=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              checkpoint_dir,
                                              max_to_keep=5)

    ckpt.restore(ckpt_manager.latest_checkpoint)
    if ckpt_manager.latest_checkpoint:
        print("Restored from {}".format(ckpt_manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    print("Starting the training ...")
    train_model(model, b, params, ckpt_manager)
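For reference, a minimal params dictionary that covers every key this train function reads. The key names are taken from the code above; all concrete values are placeholder assumptions, not the project's defaults:

# Hypothetical values: only the key names come from the snippet.
params = {
    "mode": "train",                       # checked by the assert
    "vocab_path": "data/vocab.txt",        # consumed by Vocab
    "vocab_size": 50000,                   # consumed by Vocab
    "model": "SequenceToSequence",         # selects the model and checkpoint branch
    "seq2seq_model_dir": "seq2seq_model",  # root of the checkpoint directory
}
train(params)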
Code Example #2
import tensorflow as tf  # needed below for tf.train.Checkpoint / CheckpointManager
# Vocab, batcher, PGN_TRANSFORMER and train_model are project-local imports (not shown).

def train(params):
    assert params["mode"].lower() == "train", "change training mode to 'train'"

    vocab = Vocab(params["vocab_path"], params["vocab_size"])
    print('true vocab is ', vocab)

    print("Creating the batcher ...")
    b = batcher(vocab, params)

    print("Building the model ...")
    # model = SequenceToSequence(params)
    # model = PGN(params)
    model = PGN_TRANSFORMER(params)

    print("Creating the checkpoint manager")
    checkpoint_dir = "{}/checkpoint".format(params["seq2seq_model_dir"])
    # Checkpoint so an interrupted run can resume instead of retraining from scratch.
    # Note: the directory and checkpoint key keep the seq2seq names even though
    # the model built above is PGN_TRANSFORMER.
    ckpt = tf.train.Checkpoint(SequenceToSequence=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=5)

    ckpt.restore(ckpt_manager.latest_checkpoint)
    if ckpt_manager.latest_checkpoint:
        print("Restored from {}".format(ckpt_manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")
    print("Starting the training ...")
    train_model(model, b, params, ckpt, ckpt_manager)
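The resume-after-interruption behavior noted in the comment above depends on train_model saving through the manager at a fixed interval. A minimal sketch of that pattern, assuming a hypothetical train_step function (nothing below appears in the snippet itself):

# Hypothetical loop; train_step and the 1000-step interval are assumptions.
for step, batch in enumerate(b, start=1):
    loss = train_step(model, batch)
    if step % 1000 == 0:
        save_path = ckpt_manager.save(checkpoint_number=step)  # prunes to max_to_keep=5
        print("Saved checkpoint for step {}: {}".format(step, save_path))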
Code Example #3
# Vgg16, train_model, test_model and save_array come from the project's own utilities (not shown).
def train_and_test(data_dir, no_of_epochs=4):
    results_dir = data_dir + '/results'
    batch_size = 64
    vgg = Vgg16()

    train_model(vgg, data_dir, batch_size, no_of_epochs)

    batches, preds = test_model(vgg, data_dir + '/test', batch_size=batch_size)

    save_array(results_dir + '/test_preds', preds)
    save_array(results_dir + '/filenames', batches.filenames)

    return batches, preds, vgg
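Saving the predictions and the filenames side by side lets them be re-paired later. A sketch of reading them back, assuming save_array has the usual load_array counterpart in the same utility module (an assumption; only save_array appears above):

# Hypothetical: load_array is assumed to be save_array's inverse.
preds = load_array(results_dir + '/test_preds')
filenames = load_array(results_dir + '/filenames')
for fname, pred in zip(filenames[:3], preds[:3]):
    print(fname, pred)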
Code Example #4
File: train.py  Project: KindRoach/DeepCoNN-Pytorch
import pandas
import torch

from model.DeepCoNN import DeepCoNNConfig, DeepCoNN
from utils.data_reader import get_train_dev_test_data
from utils.train_helper import train_model
from utils.word2vec_hepler import WORD_EMBEDDING_SIZE, load_embedding_weights

train_data, dev_data, test_data = get_train_dev_test_data()
know_data = pandas.concat([train_data, dev_data])

config = DeepCoNNConfig(
    num_epochs=50,
    batch_size=16,
    learning_rate=1e-3,
    l2_regularization=1e-2,
    learning_rate_decay=0.95,
    device="cuda:0" if torch.cuda.is_available() else "cpu",
    max_review_length=2048,  # Make sure this value is smaller than max_length in data_reader.py
    word_dim=WORD_EMBEDDING_SIZE,
    kernel_widths=[2, 3, 5, 7],
    kernel_deep=100,
    latent_factors=50,
    fm_k=8)

model = DeepCoNN(config, load_embedding_weights())
train_model(model, train_data, dev_data)
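If train_model does not checkpoint internally (the snippet does not show it), the trained weights can be persisted with plain PyTorch. A sketch of that round trip; the file name is arbitrary:

# Hypothetical persistence step; not part of the original script.
torch.save(model.state_dict(), "deepconn.pt")
model.load_state_dict(torch.load("deepconn.pt", map_location=config.device))
model.eval()  # disable dropout etc. before scoring test_data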