# NOTE: the import paths and fixture locations below are assumptions based on
# a typical deepartransit project layout; adjust them to the actual package.
import os

import tensorflow as tf

from deepartransit.data import data_generator
from deepartransit.models import deepar, deeparsys
from deepartransit.utils.args import get_args
from deepartransit.utils.config import (get_config_file, process_config,
                                         split_grid_config)
from deepartransit.utils.dirs import create_dirs
from deepartransit.utils.logger import Logger

# Placeholder paths to the test config files (assumed fixtures).
config_path = os.path.join('tests', 'deepar_config_test.yml')
config_path_2 = os.path.join('tests', 'deepar_config_test_2.yml')
config_file = os.path.join('tests', 'config_test.yml')


def test_config():
    # get_config_file('tests') may exit when several candidate config files
    # are found; the test only checks that the call does not crash otherwise,
    # then processes the module-level config_file.
    try:
        get_config_file('tests')
    except SystemExit:
        print("it's okay, there are probably several config files")
    config = process_config(config_file)
    print(config.cell_args)


def test_deepar_init():
    config = process_config(config_path)

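    # Build the DeepARSys model, clear any stale checkpoints, and (re)create
    # the experiment directories referenced by the config.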
    model = deeparsys.DeepARSysModel(config)
    model.delete_checkpoints()
    create_dirs([
        config.summary_dir, config.checkpoint_dir, config.plots_dir,
        config.output_dir
    ])
    assert os.path.exists(config.summary_dir)
    assert os.path.exists(config.output_dir)
    assert os.path.exists(config.checkpoint_dir)
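    # Data generator yielding (Z, X) batches, as exercised in test_data below.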
    data = data_generator.DataGenerator(config)

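    # Smoke test: run single evaluation and training steps, reload the model,
    # then launch a full training pass.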
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        logger = Logger(sess, config)
        trainer = deeparsys.DeepARSysTrainer(sess, model, data, config, logger)
        trainer.eval_step()
        trainer.train_step()

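        # Restore the latest checkpoint, if one has been written.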
        model.load(sess)

        trainer.train()
        trainer.eval_step()


def test_data_config_update():
    config = process_config(config_path_2)
    data = data_generator.DataGenerator(config)

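    # update_config() should enrich the config with the data-dependent fields
    # asserted below (numbers of covariates, features and time series).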
    config = data.update_config()
    assert 'num_cov' in config
    assert 'num_features' in config
    assert 'num_ts' in config
    assert config.batch_size == config.num_ts


def test_data():
    config = process_config(config_path)
    data = data_generator.DataGenerator(config)

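    # A batch should hold batch_size series of length cond_length + pred_length.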
    batch_Z, batch_X = next(data.next_batch(config.batch_size))
    assert batch_Z.shape[0] == config.batch_size == batch_X.shape[0]
    assert (batch_Z.shape[1] == config.cond_length + config.pred_length
            == batch_X.shape[1])

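    # The full Z and X arrays must cover the same number of time series.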
    assert data.Z.shape[0] == data.X.shape[0]


def test_splitting_config():
    config = process_config(config_file)

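    # Splitting a grid config should expand every list-valued hyperparameter
    # into separate configs, each holding a single value.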
    config_list = split_grid_config(config)

    assert isinstance(config_list, list)
    assert len(config_list)
    for c in config_list:
        for k, v in c.items():
            print(k, v)
            assert not isinstance(v, list)


# ---- Example #6 (separate test module): smoke test of the plain DeepAR model ----
def test_deepar_init():
    config = process_config(config_path)

    create_dirs([config.summary_dir, config.checkpoint_dir])
    model = deepar.DeepARModel(config)
    # model.delete_checkpoints()
    data = data_generator.DataGenerator(config)

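    # End-to-end smoke test: a full training run with the plain DeepAR trainer.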
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        logger = Logger(sess, config)
        trainer = deepar.DeepARTrainer(sess, model, data, config, logger)
        trainer.train()


if __name__ == '__main__':
    # config_path = os.path.join('deepartransit','experiments', 'deepar_dev','deepar_config.yml')

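    # Resolve the config file either from the experiment directory named in
    # args.experiment or directly from args.config.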
    try:
        args = get_args()
        print(args.experiment)
        if args.experiment:
            print('found an experiment argument:', args.experiment)
            config_file = get_config_file(
                os.path.join("experiments", args.experiment.strip()))
            print('which contains a config file', config_file)
        else:
            config_file = args.config
        print('processing the config from the config file')
        config = process_config(config_file)

    except:
        print("missing or invalid arguments")
        exit(0)

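    # Build the model and the data pipeline from the processed config.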
    model = deepar.DeepARModel(config)
    data = data_generator.DataGenerator(config)

    create_dirs([
        config.summary_dir, config.checkpoint_dir, config.plots_dir,
        config.output_dir
    ])

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        # NOTE: assumed session body, mirroring test_deepar_init above;
        # the actual script may differ.
        sess.run(init)
        logger = Logger(sess, config)
        trainer = deepar.DeepARTrainer(sess, model, data, config, logger)
        trainer.train()


# ---- Example: grid-search driver over a meta-config (separate script) ----
import pandas as pd
from timeit import default_timer as timer

if __name__ == '__main__':
    # config_path = os.path.join('deepartransit','experiments', 'deeparsys_dev','deeparsys_config.yml')
    try:
        args = get_args()
        print(args.experiment)
        if args.experiment:
            print('found an experiment argument:', args.experiment)
            meta_config_file = get_config_file(os.path.join("experiments", args.experiment))
            print("which constains a config file", meta_config_file)
        else:
            meta_config_file = args.config
        print('processing the config from the config file')
        meta_config = process_config(meta_config_file)

    except:
        print("missing or invalid arguments")
        exit(0)

    grid_config = split_grid_config(meta_config)
    # Keep only grid points whose segment lengths are mutually consistent
    # (total_length must equal pretrans_length + trans_length + postrans_length).
    list_configs = [
        c for c in grid_config
        if c['total_length']
        == c['pretrans_length'] + c['trans_length'] + c['postrans_length']
    ]

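    # One row per grid point: the hyperparameter values plus the loss/MSE and
    # timing metrics collected for that model.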
    df_scores = pd.DataFrame(index=list(range(len(list_configs))),
                             columns=list(list_configs[0].keys()) + ['loss_pred', 'nb_epochs', 'mse_pred', 'init_time',
                                                                     'training_time'])
    print('Starting to run {} models'.format(len(list_configs)))
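    # For each grid point, record its hyperparameter values in the score table.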
    for i, config in enumerate(list_configs):
        df_scores.loc[i, config.keys()] = list(config.values())
        print('\n\t\t >>>>>>>>> model ', i)