def testBadArguments(self):
    """Checks that TrainableInitialState rejects bad masks and bad states.

    A mask with a non-boolean leaf is expected to raise ``TypeError`` when the
    module is built. An ``initial_state`` built from ``tf.random_normal`` (so
    its rows generally differ) is expected to raise
    ``tf.errors.InvalidArgumentError``. The error may surface either when the
    module is connected or only once the variable initializer is evaluated,
    so both steps share one ``assertRaises`` context.
    """
    # Same nesting as the masks below: a [BATCH_SIZE, 6] tensor followed by a
    # pair of [BATCH_SIZE, 7] and [BATCH_SIZE, 8] tensors.
    leading = tf.random_normal([BATCH_SIZE, 6])
    trailing_pair = (tf.random_normal([BATCH_SIZE, 7]),
                     tf.random_normal([BATCH_SIZE, 8]))
    initial_state = (leading, trailing_pair)

    non_boolean_mask = (True, (False, "foo"))
    with self.assertRaises(TypeError):
        snt.TrainableInitialState(initial_state, mask=non_boolean_mask)

    boolean_mask = (True, (False, True))
    with self.assertRaises(tf.errors.InvalidArgumentError):
        snt.TrainableInitialState(initial_state, mask=boolean_mask)()
        # Check that the class checks that the elements of initial_state have
        # identical rows.
        self.evaluate(tf.global_variables_initializer())
def testInitialStateComputation(self, tuple_state, mask):
    """Checks TrainableInitialState values before and after a variable update.

    First verifies that, right after initialization, the module returns
    exactly the provided initial state. Then sets every trainable variable to
    ones, reconnects the module, and verifies that only masked-in (trainable)
    elements changed to ones while the rest still equal the initial state.

    Args:
      tuple_state: If true, the initial state is a nested tuple of filled
        tensors of widths 6, 7 and 8; otherwise it is a single
        [BATCH_SIZE, 9] tensor filled with 10.
      mask: Optional nested boolean structure forwarded to
        ``snt.TrainableInitialState``. When ``None``, this test expects every
        element of the state to be trainable.
    """
    if tuple_state:
        initial_state = (tf.fill([BATCH_SIZE, 6], 2),
                         (tf.fill([BATCH_SIZE, 7], 3),
                          tf.fill([BATCH_SIZE, 8], 4)))
    else:
        initial_state = tf.fill([BATCH_SIZE, 9], 10)

    trainable_state_module = snt.TrainableInitialState(initial_state, mask=mask)
    trainable_state = trainable_state_module()
    flat_trainable_state = nest.flatten(trainable_state)
    nest.assert_same_structure(initial_state, trainable_state)

    flat_initial_state = nest.flatten(initial_state)
    if mask is not None:
        flat_mask = nest.flatten(mask)
    else:
        # No mask given: treat every flattened element as trainable.
        flat_mask = (True,) * len(flat_initial_state)

    self.evaluate(tf.global_variables_initializer())

    # Check all variables are initialized correctly: the module must return a
    # state with the same values as the initial state it was given.
    # Loop variables use distinct names so the outer locals and the `mask`
    # parameter are never shadowed.
    for state_tensor, init_tensor in zip(flat_trainable_state,
                                         flat_initial_state):
        self.assertAllEqual(self.evaluate(state_tensor),
                            self.evaluate(init_tensor))

    # Change the value of all the trainable variables to ones.
    for variable in tf.trainable_variables():
        self.evaluate(tf.assign(variable, tf.ones_like(variable)))

    # In eager mode to re-evaluate the module we must re-connect it.
    trainable_state = trainable_state_module()
    flat_trainable_state = nest.flatten(trainable_state)

    # Check that the values of the initial_states have changed if and only if
    # they are trainable.
    for state_tensor, init_tensor, is_trainable in zip(
            flat_trainable_state, flat_initial_state, flat_mask):
        state_value = self.evaluate(state_tensor)
        init_value = self.evaluate(init_tensor)
        if is_trainable:
            expected_value = np.ones_like(init_value)
        else:
            expected_value = init_value
        self.assertAllEqual(state_value, expected_value)