예제 #1
0
  def testBadArguments(self):
    """Checks that a bad mask and a state with differing rows are rejected."""
    # Nested state: a single tensor followed by an inner pair of tensors.
    outer_state = tf.random_normal([BATCH_SIZE, 6])
    inner_state = (tf.random_normal([BATCH_SIZE, 7]),
                   tf.random_normal([BATCH_SIZE, 8]))
    initial_state = (outer_state, inner_state)

    # A mask holding the string "foo" as a leaf is expected to raise TypeError
    # at construction time.
    invalid_mask = (True, (False, "foo"))
    with self.assertRaises(TypeError):
      snt.TrainableInitialState(initial_state, mask=invalid_mask)

    # Check that the class checks that the elements of initial_state have
    # identical rows. Both connecting the module and running the initializer
    # are kept inside the context manager, so the error may come from either.
    valid_mask = (True, (False, True))
    with self.assertRaises(tf.errors.InvalidArgumentError):
      module = snt.TrainableInitialState(initial_state, mask=valid_mask)
      module()
      self.evaluate(tf.global_variables_initializer())
예제 #2
0
    def testInitialStateComputation(self, tuple_state, mask):
        """Checks the state produced before and after variables are changed.

        Args:
            tuple_state: If truthy, the initial state is a nested tuple of
                three constant tensors; otherwise it is a single tensor.
            mask: Passed through to `snt.TrainableInitialState`. When it is
                `None`, this test treats every state element as trainable;
                otherwise its flattened entries say which elements are
                expected to change.
        """
        if not tuple_state:
            initial_state = tf.fill([BATCH_SIZE, 9], 10)
        else:
            head = tf.fill([BATCH_SIZE, 6], 2)
            tail_first = tf.fill([BATCH_SIZE, 7], 3)
            tail_second = tf.fill([BATCH_SIZE, 8], 4)
            initial_state = (head, (tail_first, tail_second))

        module = snt.TrainableInitialState(initial_state, mask=mask)
        connected_state = module()
        flat_connected = nest.flatten(connected_state)
        nest.assert_same_structure(initial_state, connected_state)
        flat_initial = nest.flatten(initial_state)

        # With no mask supplied, expect every element to be trainable.
        flat_mask = ((True, ) * len(flat_initial)
                     if mask is None else nest.flatten(mask))

        self.evaluate(tf.global_variables_initializer())

        # Right after initialization, every element of the returned state
        # must equal the corresponding element of the provided state.
        for state_tensor, init_tensor in zip(flat_connected, flat_initial):
            self.assertAllEqual(self.evaluate(state_tensor),
                                self.evaluate(init_tensor))

        # Overwrite every trainable variable with ones.
        for variable in tf.trainable_variables():
            self.evaluate(tf.assign(variable, tf.ones_like(variable)))

        # In eager mode the module must be re-connected to be re-evaluated.
        flat_connected = nest.flatten(module())

        # Elements flagged in the mask should now be all ones; the others
        # should still hold their initial values.
        for state_tensor, init_tensor, is_trainable in zip(
                flat_connected, flat_initial, flat_mask):
            actual_value = self.evaluate(state_tensor)
            original_value = self.evaluate(init_tensor)
            expected_value = (np.ones_like(original_value)
                              if is_trainable else original_value)

            self.assertAllEqual(actual_value, expected_value)