示例#1
0
文件: test_train.py 项目: xz725/tfnet
    def test_traineval(self):
        """Run ``tf.estimator.train_and_evaluate`` on the LISTFILE_1 test data.

        Training uses 4 epochs at batch size 32. Evaluation is a single,
        unshuffled epoch at batch size 16. Both use ``segs_per_sample=16``,
        and a checkpoint is written every 2 steps.
        """
        tf.logging.set_verbosity(tf.logging.INFO)
        # NOTE(review): sys.stdout.encoding can be None when stdout is replaced
        # by a non-text stream; os.fsdecode(PATH) is the usual way to decode a
        # bytes filesystem path -- confirm how PATH is produced.
        path = PATH.decode(sys.stdout.encoding)

        # Datasets are wrapped in factories so that each one is built when the
        # Estimator calls its input_fn, i.e. inside the Estimator's own graph.
        dset = lambda: ds.dataset_with_preprocess(LISTFILE_1, path,
                                                  epochs=4,
                                                  batchsize=32,
                                                  segs_per_sample=16,
                                                 )
        dset_eval = lambda: ds.dataset_with_preprocess(LISTFILE_1, path,
                                                       epochs=1,
                                                       batchsize=16,
                                                       segs_per_sample=16,
                                                       shuffle=False
                                                      )
        # Checkpoint very often because this training run is only a few steps long.
        config = self.config.replace(save_checkpoints_steps=2)

        tfnet_est = TFNetEstimator(**nets.default_net(), config=config)

        input_fn = lambda: dset().make_one_shot_iterator().get_next()
        eval_input_fn = lambda: dset_eval().make_one_shot_iterator().get_next()

        train_spec = tf.estimator.TrainSpec(input_fn)
        eval_spec = tf.estimator.EvalSpec(eval_input_fn)

        tf.estimator.train_and_evaluate(tfnet_est, train_spec, eval_spec)

        self.assertIsNotNone(tfnet_est)
示例#2
0
def main(_):
    """Train a default TFNetEstimator for one epoch on the LISTFILE_1 data.

    The model directory is ``tests/dummymodel``. The single positional
    argument is ignored (presumably argv passed by an app runner such as
    ``tf.app.run`` -- confirm against the caller).
    """
    # NOTE(review): sys.stdout.encoding can be None when stdout is replaced
    # by a non-text stream; os.fsdecode(PATH) is the usual way to decode a
    # bytes filesystem path -- confirm how PATH is produced.
    path = PATH.decode(sys.stdout.encoding)

    # Build the dataset lazily. The Estimator calls input_fn inside a graph of
    # its own, so a dataset constructed here, in the default graph, could not
    # be iterated there.
    dset = lambda: ds.dataset_with_preprocess(
        LISTFILE_1,
        path,
        epochs=1,
        batchsize=32,
        segs_per_sample=64,
    )
    tfnet_est = TFNetEstimator(**nets.default_net(),
                               model_dir='tests/dummymodel')

    tfnet_est.train(input_fn=lambda: dset().make_one_shot_iterator().get_next())
示例#3
0
文件: test_train.py 项目: xz725/tfnet
    def test_train(self):
        """Run a short training pass over the LISTFILE_1 test data.

        Uses 64 epochs at batch size 32. Assuming LISTFILE_1 lists 2 files,
        that is 2 * 64 / 32 = 4 iterations.
        """
        tf.logging.set_verbosity(tf.logging.INFO)
        # NOTE(review): sys.stdout.encoding can be None when stdout is replaced
        # by a non-text stream; os.fsdecode(PATH) is the usual way to decode a
        # bytes filesystem path -- confirm how PATH is produced.
        path = PATH.decode(sys.stdout.encoding)

        # Factory so the dataset is built inside the Estimator's own graph,
        # at the moment input_fn is called.
        dset = lambda: ds.dataset_with_preprocess(LISTFILE_1, path,
                                                  epochs=64,
                                                  batchsize=32,
                                                 )
        tfnet_est = TFNetEstimator(**nets.default_net(), config=self.config)

        tfnet_est.train(
            input_fn=lambda: dset().make_one_shot_iterator().get_next())

        self.assertIsNotNone(tfnet_est)
示例#4
0
    def test_loadmodel(self):
        """Run prediction with a previously trained model on one audio file.

        Uses the model directory DUMMY_MODEL_PATH. Checks that predict()
        yields at least one prediction and that every prediction has shape
        (8192, 1).
        """
        tf.logging.set_verbosity(tf.logging.INFO)

        # Factory so the dataset is built when the Estimator calls input_fn,
        # inside the Estimator's own graph. A dataset built here, in the
        # default graph, could not be iterated from that graph.
        dset = lambda: ds.single_file_dataset(LQ_AUDIO_FILE)

        # log_step_count_steps=1: log on every step, since this run is very short.
        config = tf.estimator.RunConfig(log_step_count_steps=1)

        tfnet_est = TFNetEstimator(**nets.default_net(), config=config,
                                   model_dir=DUMMY_MODEL_PATH
                                  )

        preds = tfnet_est.predict(
            input_fn=lambda: dset().make_one_shot_iterator().get_next())

        n_preds = 0
        for pred in preds:
            self.assertEqual(pred.shape, (8192, 1))
            n_preds += 1
        # Without this check, the loop above would pass vacuously on empty output.
        self.assertGreater(n_preds, 0)
示例#5
0
    def test_loadmodel(self):
        """Evaluate a previously trained model on the LISTFILE_1 test data.

        Uses the model directory DUMMY_MODEL_PATH and evaluates a single
        epoch at batch size 16.
        """
        tf.logging.set_verbosity(tf.logging.INFO)

        # NOTE(review): sys.stdout.encoding can be None when stdout is replaced
        # by a non-text stream; os.fsdecode(PATH) is the usual way to decode a
        # bytes filesystem path -- confirm how PATH is produced.
        path = PATH.decode(sys.stdout.encoding)
        # Factory so the dataset is built when the Estimator calls input_fn,
        # inside the Estimator's own graph. A dataset built here, in the
        # default graph, could not be iterated from that graph.
        dset = lambda: ds.dataset_with_preprocess(
            LISTFILE_1,
            path,
            epochs=1,
            batchsize=16,
        )
        # log_step_count_steps=1: log on every step, since this run is very short.
        config = tf.estimator.RunConfig(log_step_count_steps=1)

        tfnet_est = TFNetEstimator(**nets.default_net(),
                                   config=config,
                                   model_dir=DUMMY_MODEL_PATH)

        tfnet_est.evaluate(
            input_fn=lambda: dset().make_one_shot_iterator().get_next())

        self.assertIsNotNone(tfnet_est)