Code example #1
import config
import data
import model
import utils
import os
import tensorflow as tf
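
# Driver script: load pretrained GloVe vectors, build train/test datasets,
# then construct the models under a shared "Model" variable scope.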

if __name__ == "__main__":
    # LOAD EMBEDDING
    word_to_index, index_to_word, word_to_vec, emb_matrix = utils.read_glove_vecs(
        os.path.join(config.EMBEDDING_DIR, config.EMBEDDING_PATH))
    print("Pretrained Embedding Loaded")

    #LOAD CONFIG
    train_config = config.TrainConfig()
    test_config = config.TestConfig()
    # LOAD DATA
    train_data = data.DATA(train_config)
    train_data.read_file(config.TRAIN_PATH, word_to_index)
    print("Train data Loaded")
    test_data = data.DATA(test_config)
    test_data.read_file(config.TEST_PATH, word_to_index)
    print("Test data Loaded")

    # BUILD MODEL
    #initializer = tf.random_uniform_initializer(-train_config.init_scale, train_config.init_scale)
    with tf.name_scope("Train"):
        with tf.variable_scope("Model", reuse=None):
            train_model = model.MODEL(train_config,
                                      len(word_to_index),
                                      emb_matrix)  # snippet truncated in source; emb_matrix is a plausible, unconfirmed final argument
Code example #2
        #     config.data_mean = data_mean
        #     config.data_std = data_std
        #     config.dim_to_ignore = dim_to_ignore
        #     config.dim_to_use = dim_to_use

        #set = utils.normalize_data(set, config.data_mean, config.data_std, list(range(0, 48)))

        self.data = set  # 'set' is the dataset assembled above; the name shadows the builtin

    def __getitem__(self, idx):
        if self.config.datatype == 'lie':
            sample = self.data[idx]
        elif self.config.datatype == 'xyz':
            # 'xyz' was left unhandled; without this guard, `sample` would be
            # unbound and the call below would raise UnboundLocalError.
            raise NotImplementedError("datatype 'xyz' is not supported yet")
        sample = self.formatdata(sample, False)
        return sample

    def __len__(self):
        return len(self.data)


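# Example usage: construct a TrainConfig and load the dataset through CMUDataset.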
if __name__ == '__main__':
    import config

    config = config.TrainConfig(dataset='Human',
                                datatype='smpl',
                                action='walking',
                                gpu=[0],
                                training=True,
                                visualize=False)
    data = CMUDataset(config, train=True)
Code example #3
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
import config
import tools
import datautil
from bilstm_seq2seq import BilstmSeq2Seq
import subprocess
import argparse
import threading
import platform
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--data_path", help="the path of data", default="./data/")
parser.add_argument("-s", "--save_path", help="the path of the saved model", default="./models/")
parser.add_argument("-e", "--epoch", help="the number of epoch", default=100, type=int)
parser.add_argument("-c", "--char_emb", help="the char embedding file", default="char.emb")
parser.add_argument("-w", "--word_emb", help="the word embedding file", default="term.emb.np")
parser.add_argument("-ed", "--emb_dim", help="the word embedding size", default=128)
parser.add_argument("-hd", "--hid_dim", help="the hidden size", default=128)
parser.add_argument("-g", "--gpu", help="the id of gpu, the default is 0", default=0, type=int)
parser.add_argument("-j", "--job", help="job name.", default="bilstm term-level crf", type=str)

args = parser.parse_args()

cf = config.TrainConfig(args)
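# CUDA is generally unavailable on macOS, so fall back to CPU there.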
if platform.system() == 'Darwin':
    cf.use_gpu = False

class TrainThread(threading.Thread):
    def __init__(self):
        threading.Thread.__init__(self)  # required so the thread can be started
        train_iter = datautil.queue_iter()
Code example #4
# create time: 7/20/2019

import numpy as np
import torch
import config
from loss import linearizedlie_loss
import utils
from choose_dataset import DatasetChooser
from torch import autograd
from STLN import ST_HMR, ST_LSTM
from torch.utils.data import DataLoader
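
# Entry point: build a TrainConfig, then wrap the train/test/prediction
# datasets in DataLoaders via the DatasetChooser.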

if __name__ == '__main__':

    config = config.TrainConfig('Human', 'lie', 'all')
    # choose = DatasetChooser(config)
    # prediction_dataset, _ = choose(prediction=True)
    # prediction_loader = DataLoader(prediction_dataset, batch_size=config.batch_size, shuffle=True)

    choose = DatasetChooser(config)
    train_dataset, bone_length = choose(train=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True)
    test_dataset, _ = choose(train=False)
    test_loader = DataLoader(test_dataset,
                             batch_size=config.batch_size,
                             shuffle=True)
    prediction_dataset, bone_length = choose(prediction=True)
    prediction_loader = DataLoader(prediction_dataset,
                                   batch_size=config.batch_size,
                                   shuffle=True)
Code example #5
        type=str,
        default='all',
        dest="action",
        help="choose one action in the dataset:"
        "h3.6m_actions = ['directions', 'discussion', 'eating', 'greeting', 'phoning', 'posing', 'purchases', 'sitting',"
        "'sittingdown', 'smoking', 'takingphoto', 'waiting', 'walking', 'walkingdog', 'walkingtogether']"
        "'all means all of the above")
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        dest="dataset",
                        help="choose dataset from 'Human' or 'Mouse'")
    parser.add_argument("--datatype",
                        type=str,
                        default='smpl',
                        dest="datatype",
                        help="only lie is usable")
    parser.add_argument("--visualize",
                        type=bool,
                        default=False,
                        dest="visualize",
                        help="visualize the prediction or not ")
    args = parser.parse_args()
    config = config.TrainConfig(args.dataset, args.datatype, args.action,
                                args.gpu, args.training, args.visualize)
    checkpoint_dir, output_dir = utils.create_directory(config)
    if config.train_model is True:
        train(config, checkpoint_dir)
    else:
        prediction(config, checkpoint_dir, output_dir)
Code example #6
from argparse import ArgumentParser

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pl_bolts.datamodules import CIFAR10DataModule

import config as cfg  # assumed: the snippet calls cfg.TrainConfig() below without showing this import
from model import LightningModel


def init_trainer():
    """ Init a Lightning Trainer using from_argparse_args
    Thus every CLI command (--gpus, distributed_backend, ...) become available.
    """
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    lr_logger = LearningRateMonitor()
    early_stopping = EarlyStopping(monitor='val_loss',
                                   mode='min',
                                   min_delta=0.001,
                                   patience=10,
                                   verbose=True)
    return Trainer.from_argparse_args(args,
                                      callbacks=[lr_logger, early_stopping])


def run_training(config):
    """ Instanciate a datamodule, a model and a trainer and run trainer.fit(model, data) """
    data = CIFAR10DataModule(config.rootdir)
    model = LightningModel(config)
    trainer = init_trainer()
    trainer.fit(model, data)


if __name__ == '__main__':
    run_training(cfg.TrainConfig())
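
All six snippets revolve around a project-specific config.TrainConfig, each with a different constructor signature. As a point of reference, the following is a minimal, hypothetical sketch of such a config object; every field name below is an assumption inferred from the examples above, not the actual API of any of these projects.

# Hypothetical sketch only; not taken from any of the projects above.
from dataclasses import dataclass, field
from typing import List


@dataclass
class TrainConfig:
    """Minimal stand-in for the project-specific config.TrainConfig objects."""
    dataset: str = 'Human'          # e.g. 'Human' or 'Mouse' (see example #5)
    datatype: str = 'lie'           # example #5 notes only 'lie' is usable
    action: str = 'all'             # one h3.6m action name, or 'all'
    gpu: List[int] = field(default_factory=lambda: [0])
    training: bool = True
    visualize: bool = False
    batch_size: int = 32            # consumed by the DataLoader calls in example #4
    train_model: bool = True        # selects train vs. prediction in example #5


if __name__ == '__main__':
    cfg = TrainConfig(dataset='Human', datatype='lie', action='walking')
    print(cfg)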