def testLearnSineFunction(self):
        """Checks that `_MultiValueRNNEstimator` fits a sine curve.

        A multi-value RNN regressor is trained to predict the next sample of
        a sine sequence from the samples seen so far; the evaluation loss on
        an independently seeded input stream must fall below a threshold.
        """
        # Model / optimisation settings.
        num_units = 4
        lr = 0.1
        # Data / schedule settings.
        batch = 8
        steps_per_sequence = 64
        fit_steps = 200
        evaluate_steps = 20
        step_size = np.pi / 32
        max_loss = 0.02

        def make_input_fn(batch_size, sequence_length, increment, seed=None):
            """Builds an input_fn yielding ({'inputs': ...}, labels) batches."""
            num_points = sequence_length + 1
            # NOTE(review): the stop value does not add the random start, so
            # the real spacing equals `increment` only when start == 0.
            # Kept unchanged to preserve the data the test was tuned on.
            stop = (sequence_length - 1) * increment

            def curve_for(elems):
                start = tf.reshape(elems[0], [])
                return tf.sin(tf.linspace(start, stop, num_points))

            def input_fn():
                phases = tf.random_uniform([batch_size], maxval=(2 * np.pi), seed=seed)
                curves = tf.map_fn(curve_for, (phases,), dtype=tf.float32)
                window = [batch_size, sequence_length]
                # Inputs are samples [0, T); labels are the same curves shifted by one.
                current = tf.slice(curves, [0, 0], window)
                features = {"inputs": tf.expand_dims(current, 2)}
                targets = tf.slice(curves, [0, 1], window)
                return features, targets

            return input_fn

        estimator = dynamic_rnn_estimator.multi_value_rnn_regressor(
            num_units=num_units,
            learning_rate=lr,
            config=tf.contrib.learn.RunConfig(tf_random_seed=1234),
        )

        training_fn = make_input_fn(batch, steps_per_sequence, step_size, seed=1234)
        evaluation_fn = make_input_fn(batch, steps_per_sequence, step_size, seed=4321)

        estimator.fit(input_fn=training_fn, steps=fit_steps)
        metrics = estimator.evaluate(input_fn=evaluation_fn, steps=evaluate_steps)
        loss = metrics["loss"]
        failure_msg = "Loss should be less than {}; got {}".format(max_loss, loss)
        self.assertLess(loss, max_loss, failure_msg)
Example #2
0
    def testLearnSineFunction(self):
        """Tests that `_MultiValueRNNEstimator` can learn a sine function.

        Each example is a sine curve sampled at ``sequence_length + 1``
        points starting from a random phase. The model receives samples
        ``[0, T)`` as inputs and is trained to predict samples ``[1, T]``.
        The test asserts that the evaluation loss on a differently seeded
        input stream is below ``loss_threshold``.
        """
        batch_size = 8
        sequence_length = 64
        train_steps = 200
        eval_steps = 20
        cell_size = 4
        learning_rate = 0.1
        loss_threshold = 0.02
        # Every time step carries a single scalar sample (the input tensor is
        # [batch_size, sequence_length, 1] after `expand_dims`), so the
        # feature column is one-dimensional. This is unrelated to the RNN's
        # `cell_size`, which the column was previously (wrongly) tied to.
        input_dimension = 1

        def get_sin_input_fn(batch_size,
                             sequence_length,
                             increment,
                             seed=None):
            """Returns an input_fn producing batches of shifted sine curves."""
            def _sin_fn(x):
                # NOTE(review): the stop value does not include the start
                # offset, so the sample spacing equals `increment` only when
                # the start is 0. Preserved to keep the training data the
                # loss threshold was tuned on.
                ranger = tf.linspace(tf.reshape(x[0], []),
                                     (sequence_length - 1) * increment,
                                     sequence_length + 1)
                return tf.sin(ranger)

            def input_fn():
                starts = tf.random_uniform([batch_size],
                                           maxval=(2 * np.pi),
                                           seed=seed)
                # [batch_size, sequence_length + 1]
                sin_curves = tf.map_fn(_sin_fn, (starts, ), dtype=tf.float32)
                # Inputs: first `sequence_length` samples, with a trailing
                # feature axis -> [batch_size, sequence_length, 1].
                inputs = tf.expand_dims(
                    tf.slice(sin_curves, [0, 0],
                             [batch_size, sequence_length]), 2)
                # Labels: the same curves shifted one step ahead.
                labels = tf.slice(sin_curves, [0, 1],
                                  [batch_size, sequence_length])
                return {'inputs': inputs}, labels

            return input_fn

        seq_columns = [
            tf.contrib.layers.real_valued_column('inputs',
                                                 dimension=input_dimension)
        ]
        config = tf.contrib.learn.RunConfig(tf_random_seed=1234)
        sequence_estimator = dynamic_rnn_estimator.multi_value_rnn_regressor(
            num_units=cell_size,
            sequence_feature_columns=seq_columns,
            learning_rate=learning_rate,
            config=config)

        train_input_fn = get_sin_input_fn(batch_size,
                                          sequence_length,
                                          np.pi / 32,
                                          seed=1234)
        eval_input_fn = get_sin_input_fn(batch_size,
                                         sequence_length,
                                         np.pi / 32,
                                         seed=4321)

        sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
        loss = sequence_estimator.evaluate(input_fn=eval_input_fn,
                                           steps=eval_steps)['loss']
        self.assertLess(
            loss, loss_threshold,
            'Loss should be less than {}; got {}'.format(loss_threshold, loss))
  def testLearnSineFunction(self):
    """Tests learning a sine function.

    Each example is a sine curve sampled at `sequence_length + 1` points
    starting from a random phase. The regressor (with input/output dropout)
    receives samples [0, T) and is trained to predict samples [1, T]; the
    evaluation loss on a differently seeded stream must be below
    `loss_threshold`.
    """
    batch_size = 8
    sequence_length = 64
    train_steps = 200
    eval_steps = 20
    cell_size = 4
    learning_rate = 0.1
    loss_threshold = 0.02
    # Every time step carries one scalar sample (inputs are
    # [batch_size, sequence_length, 1] after `expand_dims`), so the feature
    # column is one-dimensional; this is unrelated to the RNN `cell_size`.
    input_dimension = 1

    def get_sin_input_fn(batch_size, sequence_length, increment, seed=None):
      """Returns an input_fn producing batches of shifted sine curves."""

      def _sin_fn(x):
        # NOTE(review): the stop value does not include the start offset, so
        # the spacing equals `increment` only when the start is 0. Preserved
        # to keep the data the loss threshold was tuned on.
        ranger = math_ops.linspace(
            array_ops.reshape(x[0], []), (sequence_length - 1) * increment,
            sequence_length + 1)
        return math_ops.sin(ranger)

      def input_fn():
        starts = random_ops.random_uniform(
            [batch_size], maxval=(2 * np.pi), seed=seed)
        # [batch_size, sequence_length + 1]
        sin_curves = functional_ops.map_fn(
            _sin_fn, (starts,), dtype=dtypes.float32)
        # Inputs: first `sequence_length` samples plus a trailing feature
        # axis -> [batch_size, sequence_length, 1].
        inputs = array_ops.expand_dims(
            array_ops.slice(sin_curves, [0, 0], [batch_size, sequence_length]),
            2)
        # Labels: the same curves shifted one step ahead.
        labels = array_ops.slice(sin_curves, [0, 1],
                                 [batch_size, sequence_length])
        return {'inputs': inputs}, labels

      return input_fn

    seq_columns = [
        feature_column.real_valued_column(
            'inputs', dimension=input_dimension)
    ]
    config = run_config.RunConfig(tf_random_seed=1234)
    sequence_estimator = dynamic_rnn_estimator.multi_value_rnn_regressor(
        num_units=cell_size,
        sequence_feature_columns=seq_columns,
        learning_rate=learning_rate,
        input_keep_probability=0.9,
        output_keep_probability=0.9,
        config=config)

    train_input_fn = get_sin_input_fn(
        batch_size, sequence_length, np.pi / 32, seed=1234)
    eval_input_fn = get_sin_input_fn(
        batch_size, sequence_length, np.pi / 32, seed=4321)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
    loss = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)['loss']
    self.assertLess(loss, loss_threshold,
                    'Loss should be less than {}; got {}'.format(loss_threshold,
                                                                 loss))