def test_deepar_init():
    """End-to-end smoke test for the DeepARSys model and trainer.

    Loads the experiment config from ``config_path`` and builds a
    ``DeepARSysModel``, calling ``delete_checkpoints()`` on it. It then creates
    the summary/checkpoint/plots/output directories and asserts that each one
    exists. Finally, inside a TF session, it runs one eval step and one train
    step, calls ``model.load(sess)``, runs ``trainer.train()`` and a final
    eval step.
    """
    config = process_config(config_path)

    model = deeparsys.DeepARSysModel(config)
    model.delete_checkpoints()
    required_dirs = [
        config.summary_dir,
        config.checkpoint_dir,
        config.plots_dir,
        config.output_dir,
    ]
    create_dirs(required_dirs)
    # Verify every directory that was requested, including plots_dir.
    for directory in required_dirs:
        assert os.path.exists(directory), directory
    data = data_generator.DataGenerator(config)

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        logger = Logger(sess, config)
        trainer = deeparsys.DeepARSysTrainer(sess, model, data, config, logger)
        trainer.eval_step()
        trainer.train_step()

        # NOTE(review): presumably restores the latest checkpoint into this
        # session before the full training run — confirm load() semantics.
        model.load(sess)

        trainer.train()
        trainer.eval_step()
def test_base_model():
    """Smoke-test BaseModel setup and a single Logger summary call.

    Wraps the module-level ``config`` in a ``Bunch`` and builds a
    ``BaseModel`` from it. It initialises the saver and calls
    ``gaussian_likelihood`` on the class. It then creates the summary and
    checkpoint directories. Finally, inside a TF session, it calls
    ``Logger.summarize`` at step 0 with an empty summaries dict.
    """
    cfg = Bunch(config)
    model = base.BaseModel(cfg)
    model.init_saver()
    base.BaseModel.gaussian_likelihood(1.0)
    create_dirs([cfg.summary_dir, cfg.checkpoint_dir])

    init_op = tf.global_variables_initializer()
    with tf.Session() as session:
        session.run(init_op)
        summary_logger = Logger(session, cfg)
        summary_logger.summarize(0, summaries_dict={})
# Example #3
# 0
def test_deepar_init():
    """Smoke test: build a DeepAR model and trainer, then run ``train()``.

    Loads the config from ``config_path``, creates the summary and checkpoint
    directories, builds the model and data generator, and runs the trainer's
    full training loop inside a TF session.
    """
    cfg = process_config(config_path)

    create_dirs([cfg.summary_dir, cfg.checkpoint_dir])
    net = deepar.DeepARModel(cfg)
    # NOTE(review): existing checkpoints are not deleted before training
    # here — confirm this is intended.
    generator = data_generator.DataGenerator(cfg)

    init_op = tf.global_variables_initializer()
    with tf.Session() as session:
        session.run(init_op)
        run_logger = Logger(session, cfg)
        deepar.DeepARTrainer(session, net, generator, cfg, run_logger).train()
                os.path.join("experiments", args.experiment.strip()))
            print(config_file)
        else:
            config_file = args.config
        print('ok2')
        config = process_config(config_file)

    except:
        print("missing or invalid arguments")
        exit(0)

    model = deepar.DeepARModel(config)
    data = data_generator.DataGenerator(config)

    create_dirs([
        config.summary_dir, config.checkpoint_dir, config.plots_dir,
        config.output_dir
    ])

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        logger = Logger(sess, config)
        trainer = deepar.DeepARTrainer(sess, model, data, config, logger)

        if trainer.config.from_scratch:
            model.delete_checkpoints()
        model.load(sess)

        trainer.train(verbose=True)
        samples = trainer.sample_on_test()