def test_tas_for_tensors(self):
  """Checks that tas_for_tensors mirrors a nested structure of Tensors.

  A [5, 4] tensor is placed several times in a nested list/tuple/ExampleTuple
  structure and converted to TensorArrays of length 5. TensorArrays cannot be
  fetched with sess.run, so each one is stacked back into a Tensor and then
  compared leaf by leaf with the Tensor it was built from.
  """
  base = tf.reshape(tf.range(20), [5, 4])
  source = [base, (base, ExampleTuple(base, base))]
  arrays = nested_utils.tas_for_tensors(source, 5)
  # The TensorArray structure must match the input structure exactly.
  nest.assert_same_structure(source, arrays)
  # Convert back to Tensors so the values can be fetched and compared.
  restacked = nested_utils.map_nested(lambda ta: ta.stack(), arrays)
  with self.test_session() as sess:
    expected_vals, actual_vals = sess.run([source, restacked])
  # Compare every leaf of the original against its round-tripped counterpart.
  pairs = zip(nest.flatten(expected_vals), nest.flatten(actual_vals))
  for expected, actual in pairs:
    self.assertAllClose(expected, actual)
def test_read_tas(self):
  """Checks that read_tas reads the same index from every nested TensorArray.

  TensorArrays of length 5 are built from a nested structure of [5, 4]
  tensors, and index 3 is read from all of them. The result must have the
  same structure as the TensorArrays, and each leaf must equal row 3 of the
  source tensor.
  """
  base = tf.reshape(tf.range(20), [5, 4])
  row = base[3, :]
  source = [base, (base, ExampleTuple(base, base))]
  expected_struct = [row, (row, ExampleTuple(row, row))]
  arrays = nested_utils.tas_for_tensors(source, 5)
  read_struct = nested_utils.read_tas(arrays, 3)
  nest.assert_same_structure(arrays, read_struct)
  with self.test_session() as sess:
    expected_vals, actual_vals = sess.run([expected_struct, read_struct])
  # Compare the values read from the TensorArrays leaf by leaf.
  pairs = zip(nest.flatten(expected_vals), nest.flatten(actual_vals))
  for expected, actual in pairs:
    self.assertAllClose(expected, actual)
def test_tas_for_tensors(self):
  """Checks that tas_for_tensors mirrors a nested structure of Tensors.

  A [5, 4] tensor is placed several times in a nested list/tuple/ExampleTuple
  structure and converted to TensorArrays of length 5. TensorArrays cannot be
  fetched with sess.run, so each one is stacked back into a Tensor and then
  compared leaf by leaf with the Tensor it was built from.
  """
  base = tf.reshape(tf.range(20), [5, 4])
  source = [base, (base, ExampleTuple(base, base))]
  arrays = nested_utils.tas_for_tensors(source, 5)
  # The TensorArray structure must match the input structure exactly.
  nest.assert_same_structure(source, arrays)
  # Convert back to Tensors so the values can be fetched and compared.
  restacked = nested_utils.map_nested(lambda ta: ta.stack(), arrays)
  with self.test_session() as sess:
    expected_vals, actual_vals = sess.run([source, restacked])
  # Compare every leaf of the original against its round-tripped counterpart.
  pairs = zip(nest.flatten(expected_vals), nest.flatten(actual_vals))
  for expected, actual in pairs:
    self.assertAllClose(expected, actual)
def test_read_tas(self):
  """Checks that read_tas reads the same index from every nested TensorArray.

  TensorArrays of length 5 are built from a nested structure of [5, 4]
  tensors, and index 3 is read from all of them. The result must have the
  same structure as the TensorArrays, and each leaf must equal row 3 of the
  source tensor.
  """
  base = tf.reshape(tf.range(20), [5, 4])
  row = base[3, :]
  source = [base, (base, ExampleTuple(base, base))]
  expected_struct = [row, (row, ExampleTuple(row, row))]
  arrays = nested_utils.tas_for_tensors(source, 5)
  read_struct = nested_utils.read_tas(arrays, 3)
  nest.assert_same_structure(arrays, read_struct)
  with self.test_session() as sess:
    expected_vals, actual_vals = sess.run([expected_struct, read_struct])
  # Compare the values read from the TensorArrays leaf by leaf.
  pairs = zip(nest.flatten(expected_vals), nest.flatten(actual_vals))
  for expected, actual in pairs:
    self.assertAllClose(expected, actual)
def set_observations(self, observations, seq_lengths):
  """Hands the model every observed variable, inputs and targets alike.

  Call this before any computation that needs the observations, such as
  training the model or computing bounds. It is also the place to run any
  preprocessing the model requires.

  Args:
    observations: A possibly nested structure of Tensors holding all of the
      model's observations (inputs and targets). Typically each Tensor has
      shape [max_seq_len, batch_size, data_size].
    seq_lengths: An int Tensor of shape [batch_size] giving the length of each
      sequence in the batch; sequences may be padded to a common length.
  """
  self.observations = observations
  # The longest sequence in the batch determines the TensorArray length.
  longest = tf.reduce_max(seq_lengths)
  self.max_seq_len = longest
  # clear_after_read=False keeps each element readable more than once.
  self.observations_ta = nested.tas_for_tensors(
      observations, longest, clear_after_read=False)
  self.seq_lengths = seq_lengths
def set_observations(self, observations, seq_lengths):
  """Hands the model every observed variable, inputs and targets alike.

  Call this before any computation that needs the observations, such as
  training the model or computing bounds. It is also the place to run any
  preprocessing the model requires.

  Args:
    observations: A possibly nested structure of Tensors holding all of the
      model's observations (inputs and targets). Typically each Tensor has
      shape [max_seq_len, batch_size, data_size].
    seq_lengths: An int Tensor of shape [batch_size] giving the length of each
      sequence in the batch; sequences may be padded to a common length.
  """
  self.observations = observations
  # The longest sequence in the batch determines the TensorArray length.
  longest = tf.reduce_max(seq_lengths)
  self.max_seq_len = longest
  # clear_after_read=False keeps each element readable more than once.
  self.observations_ta = nested.tas_for_tensors(
      observations, longest, clear_after_read=False)
  self.seq_lengths = seq_lengths