Exemplo n.º 1
0
  def testFlattenAndPack(self):
    """Checks that flatten/pack_sequence_as round-trip nested structures.

    Covers plain nested tuples, namedtuples, non-sequence ("scalar")
    structures, and the errors raised for mismatched or malformed inputs.
    """
    structure = ((3, 4), 5, (6, 7, (9, 10), 8))
    flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
    self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
    self.assertEqual(
        nest.pack_sequence_as(structure, flat), (("a", "b"), "c",
                                                 ("d", "e", ("f", "g"), "h")))

    # Namedtuples flatten field by field and are rebuilt with usable fields.
    point = collections.namedtuple("Point", ["x", "y"])
    structure = (point(x=4, y=2), ((point(x=1, y=0),),))
    flat = [4, 2, 1, 0]
    self.assertEqual(nest.flatten(structure), flat)
    restructured_from_flat = nest.pack_sequence_as(structure, flat)
    self.assertEqual(restructured_from_flat, structure)
    self.assertEqual(restructured_from_flat[0].x, 4)
    self.assertEqual(restructured_from_flat[0].y, 2)
    self.assertEqual(restructured_from_flat[1][0][0].x, 1)
    self.assertEqual(restructured_from_flat[1][0][0].y, 0)

    # A non-sequence value (a numpy array included) is a single leaf.
    self.assertEqual([5], nest.flatten(5))
    self.assertEqual([np.array([5])], nest.flatten(np.array([5])))

    self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
    self.assertEqual(
        np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))

    # `assertRaisesRegex` replaces the `assertRaisesRegexp` alias, which was
    # deprecated in Python 3.2 and removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, "Structure is a scalar"):
      nest.pack_sequence_as("scalar", [4, 5])

    with self.assertRaisesRegex(TypeError, "flat_sequence"):
      nest.pack_sequence_as([4, 5], "bad_sequence")

    # The flat sequence has fewer leaves than the structure needs.
    with self.assertRaises(ValueError):
      nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
Exemplo n.º 2
0
    def testFlattenAndPack(self):
        """Checks that flatten/pack_sequence_as round-trip nested structures.

        Covers plain nested tuples, namedtuples, non-sequence ("scalar")
        structures, and the errors raised for mismatched or malformed inputs.
        """
        structure = ((3, 4), 5, (6, 7, (9, 10), 8))
        flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
        self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
        self.assertEqual(nest.pack_sequence_as(structure, flat),
                         (("a", "b"), "c", ("d", "e", ("f", "g"), "h")))

        # Namedtuples flatten field by field and are rebuilt with usable fields.
        point = collections.namedtuple("Point", ["x", "y"])
        structure = (point(x=4, y=2), ((point(x=1, y=0), ), ))
        flat = [4, 2, 1, 0]
        self.assertEqual(nest.flatten(structure), flat)
        restructured_from_flat = nest.pack_sequence_as(structure, flat)
        self.assertEqual(restructured_from_flat, structure)
        self.assertEqual(restructured_from_flat[0].x, 4)
        self.assertEqual(restructured_from_flat[0].y, 2)
        self.assertEqual(restructured_from_flat[1][0][0].x, 1)
        self.assertEqual(restructured_from_flat[1][0][0].y, 0)

        # A non-sequence value (a numpy array included) is a single leaf.
        self.assertEqual([5], nest.flatten(5))
        self.assertEqual([np.array([5])], nest.flatten(np.array([5])))

        self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
        self.assertEqual(np.array([5]),
                         nest.pack_sequence_as("scalar", [np.array([5])]))

        # `assertRaisesRegex` replaces the `assertRaisesRegexp` alias, which
        # was deprecated in Python 3.2 and removed in Python 3.12.
        with self.assertRaisesRegex(ValueError, "Structure is a scalar"):
            nest.pack_sequence_as("scalar", [4, 5])

        with self.assertRaisesRegex(TypeError, "flat_sequence"):
            nest.pack_sequence_as([4, 5], "bad_sequence")

        # The flat sequence has fewer leaves than the structure needs.
        with self.assertRaises(ValueError):
            nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
Exemplo n.º 3
0
  def __init__(self, dataset):
    """Sets up an eager-mode iterator over the elements of `dataset`.

    Usage:
    ```python
    dataset = tf.contrib.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Args:
      dataset: The `tf.contrib.data.Dataset` whose elements are produced.

    Raises:
      RuntimeError: If called while eager execution is not enabled.
    """
    eager_enabled = context.in_eager_mode()
    if not eager_enabled:
      template = "{} objects only make sense when eager execution is enabled"
      raise RuntimeError(template.format(type(self)))

    # The dataset's handle is obtained first; it is wired into the iterator
    # resource at the very end.
    variant = dataset.make_dataset_resource()

    # Keep the nested type structure plus the flat type/shape lists that the
    # iterator op is built with.
    self._output_types = dataset.output_types
    self._flat_output_types = nest.flatten(dataset.output_types)
    self._flat_output_shapes = nest.flatten(dataset.output_shapes)

    shared_name = _iterator_shared_name()
    self._resource = gen_dataset_ops.iterator(
        container="",
        shared_name=shared_name,
        output_types=self._flat_output_types,
        output_shapes=self._flat_output_shapes)
    gen_dataset_ops.make_iterator(variant, self._resource)
    def __init__(self, dataset):
        """Sets up an eager-mode iterator over the elements of `dataset`.

        Usage:
        ```python
        dataset = tf.contrib.data.Dataset.range(4)
        for x in Iterator(dataset):
          print(x)
        ```

        Args:
            dataset: The `tf.contrib.data.Dataset` whose elements are produced.

        Raises:
            RuntimeError: If called while eager execution is not enabled.
        """
        if not context.in_eager_mode():
            message = ("{} objects only make sense when eager execution is "
                       "enabled").format(type(self))
            raise RuntimeError(message)

        dataset_handle = dataset.make_dataset_resource()

        # Nested type structure, then the flat lists the iterator op expects.
        self._output_types = dataset.output_types
        self._flat_output_types = nest.flatten(dataset.output_types)
        self._flat_output_shapes = nest.flatten(dataset.output_shapes)

        iterator_kwargs = dict(
            container="",
            shared_name=_iterator_shared_name(),
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)
        self._resource = gen_dataset_ops.iterator(**iterator_kwargs)
        # Attach the freshly created iterator resource to the dataset handle.
        gen_dataset_ops.make_iterator(dataset_handle, self._resource)
Exemplo n.º 5
0
 def testFlattenDictOrder(self):
   """`flatten` must emit dict values in sorted-key order.

   A plain dict and an OrderedDict are built from the same unsorted insertion
   order, so a pass shows that insertion order is ignored for both.
   """
   items = [("d", 3), ("b", 1), ("a", 0), ("c", 2)]
   ordered = collections.OrderedDict(items)
   plain = dict(items)
   results = [nest.flatten(mapping) for mapping in (ordered, plain)]
   for flat in results:
     self.assertEqual([0, 1, 2, 3], flat)
Exemplo n.º 6
0
 def make_dataset_resource(self):
   """Builds the sloppy-interleave dataset op for this dataset.

   The input dataset's resource is created first. It is then passed, together
   with the mapped function, that function's captured inputs, the cycle and
   block lengths, and the flattened output types/shapes, to
   `gen_dataset_ops.sloppy_interleave_dataset`, whose result is returned.
   """
   input_resource = self._input_dataset.make_dataset_resource()
   captured = self._map_func.captured_inputs
   cycle_length = self._cycle_length
   block_length = self._block_length
   interleave_fn = self._map_func
   flat_types = nest.flatten(self.output_types)
   flat_shapes = nest.flatten(self.output_shapes)
   return gen_dataset_ops.sloppy_interleave_dataset(
       input_resource, captured, cycle_length, block_length,
       f=interleave_fn, output_types=flat_types, output_shapes=flat_shapes)
Exemplo n.º 7
0
  def __init__(self, input_dataset, map_func, cycle_length, block_length):
    """Initializes a `SloppyInterleaveDataset`.

    See `tf.contrib.data.sloppy_interleave()` for details.

    Args:
      input_dataset: The source dataset. Its flattened `output_types` give the
        argument types of the traced wrapper, and its `output_shapes` are set
        on those arguments.
      map_func: A callable that receives one input element, unpacked into
        positional arguments when the element is a nested sequence and passed
        as a single argument otherwise. It must return a
        `dataset_ops.Dataset`.
      cycle_length: Converted to an int64 tensor named "cycle_length".
      block_length: Converted to an int64 tensor named "block_length".

    Raises:
      TypeError: If `map_func` does not return a `Dataset`. This is raised
        while the wrapper is traced; NOTE(review): presumably that happens
        during the `add_to_graph` call in this constructor. Confirm.
    """
    super(SloppyInterleaveDataset, self).__init__()
    self._input_dataset = input_dataset

    # Wrap `map_func` as a graph function whose argument types are the
    # flattened component types of `input_dataset`.
    @function.Defun(*nest.flatten(input_dataset.output_types))
    def tf_map_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
        arg.set_shape(shape)

      # Rebuild the nested element structure from the flat argument list.
      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)

      # Sequence-structured elements are unpacked into positional arguments;
      # a non-sequence element is passed as a single argument.
      if nest.is_sequence(nested_args):
        dataset = map_func(*nested_args)
      else:
        dataset = map_func(nested_args)

      if not isinstance(dataset, dataset_ops.Dataset):
        raise TypeError("`map_func` must return a `Dataset` object.")

      # Side effect of tracing: this object's output structure is taken from
      # the dataset that `map_func` returned.
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes

      return dataset.make_dataset_resource()

    self._map_func = tf_map_func
    # NOTE(review): presumably this is what traces `tf_map_func`, which sets
    # `_output_types`/`_output_shapes` as a side effect. Confirm.
    self._map_func.add_to_graph(ops.get_default_graph())

    # Both lengths are stored as int64 tensors.
    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")
  def testConcatenateDatasetDifferentShape(self):
    """Concatenates datasets whose second component has different widths.

    The static shape keeps the shared width (20) and leaves the disagreeing
    one unknown. Iteration yields the 4 rows of the first dataset, then the 5
    rows of the second, then raises OutOfRangeError.
    """
    first_parts = (
        np.tile(np.array([[1], [2], [3], [4]]), 20),
        np.tile(np.array([[12], [13], [14], [15]]), 4))
    second_parts = (
        np.tile(np.array([[1], [2], [3], [4], [5]]), 20),
        np.tile(np.array([[12], [13], [14], [15], [16]]), 15))

    first = dataset_ops.Dataset.from_tensor_slices(first_parts)
    second = dataset_ops.Dataset.from_tensor_slices(second_parts)
    concatenated = first.concatenate(second)
    static_shapes = [
        s.as_list() for s in nest.flatten(concatenated.output_shapes)]
    self.assertEqual(static_shapes, [[20], [None]])

    iterator = concatenated.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for step in range(9):
        result = sess.run(get_next)
        if step < 4:
          source, row = first_parts, step
        else:
          source, row = second_parts, step - 4
        for expected, actual in zip(source, result):
          self.assertAllEqual(expected[row], actual)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
Exemplo n.º 9
0
    def testConcatenateDatasetDifferentShape(self):
        """Concatenates datasets whose second component has different widths.

        The static shape keeps the shared width (20) and leaves the
        disagreeing one unknown. Iteration yields all rows of the first
        dataset, then all rows of the second, then raises OutOfRangeError.
        """
        head_components = (np.tile(np.array([[1], [2], [3], [4]]), 20),
                           np.tile(np.array([[12], [13], [14], [15]]), 4))
        tail_components = (
            np.tile(np.array([[1], [2], [3], [4], [5]]), 20),
            np.tile(np.array([[12], [13], [14], [15], [16]]), 15))

        head = dataset_ops.Dataset.from_tensor_slices(head_components)
        tail = dataset_ops.Dataset.from_tensor_slices(tail_components)
        concatenated = head.concatenate(tail)
        flat_shapes = nest.flatten(concatenated.output_shapes)
        self.assertEqual([shape.as_list() for shape in flat_shapes],
                         [[20], [None]])

        iterator = concatenated.make_initializable_iterator()
        init_op = iterator.initializer
        get_next = iterator.get_next()

        with self.test_session() as sess:
            sess.run(init_op)
            # The first 4 elements come from `head`, the next 5 from `tail`.
            for row in range(4):
                result = sess.run(get_next)
                for component, got in zip(head_components, result):
                    self.assertAllEqual(component[row], got)
            for row in range(5):
                result = sess.run(get_next)
                for component, got in zip(tail_components, result):
                    self.assertAllEqual(component[row], got)
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)
    def testRestructureDataset(self):
        """Checks which type/shape overrides _RestructuredDataset accepts."""
        scalar_ph = array_ops.placeholder(dtypes.int32)
        vector_ph = array_ops.placeholder(dtypes.int32, shape=[None])
        matrix_ph = array_ops.placeholder(dtypes.int32, shape=[20, 30])
        dataset = dataset_ops.Dataset.from_tensors(
            (scalar_ph, (vector_ph, matrix_ph)))

        i32 = dtypes.int32

        # Three int32 leaves (nesting may differ) with unspecified or
        # compatible shapes.
        accepted = [
            ((i32, i32, i32), None),
            (((i32, i32), i32), None),
            ((i32, i32, i32), (None, None, None)),
            ((i32, i32, i32), ([17], [17], [20, 30])),
        ]
        for types, shape_lists in accepted:
            # pylint: disable=protected-access
            restructured = dataset_ops._RestructuredDataset(
                dataset, types, shape_lists)
            # pylint: enable=protected-access
            self.assertEqual(types, restructured.output_types)
            if shape_lists is None:
                continue
            pairs = zip(nest.flatten(shape_lists),
                        nest.flatten(restructured.output_shapes))
            for expected, actual in pairs:
                if expected is None:
                    self.assertIs(None, actual.ndims)
                else:
                    self.assertEqual(expected, actual.as_list())

        # dtype mismatch, wrong leaf count, mismatched shape nesting, wrong
        # shape count, and an incompatible static shape.
        rejected = [
            ((i32, dtypes.int64, i32), None),
            ((i32, i32, i32, i32), None),
            ((i32, i32, i32), ((None, None), None)),
            ((i32, i32, i32), (None, None, None, None)),
            ((i32, i32, i32), (None, [None], [21, 30])),
        ]
        for types, shape_lists in rejected:
            with self.assertRaises(ValueError):
                # pylint: disable=protected-access
                dataset_ops._RestructuredDataset(dataset, types, shape_lists)
                # pylint: enable=protected-access
  def testRestructureDataset(self):
    """Checks which type/shape overrides _RestructuredDataset accepts."""
    scalar_ph = array_ops.placeholder(dtypes.int32)
    vector_ph = array_ops.placeholder(dtypes.int32, shape=[None])
    matrix_ph = array_ops.placeholder(dtypes.int32, shape=[20, 30])
    dataset = dataset_ops.Dataset.from_tensors(
        (scalar_ph, (vector_ph, matrix_ph)))

    i32 = dtypes.int32

    # Three int32 leaves (nesting may differ) with unspecified or compatible
    # shapes.
    accepted = [
        ((i32, i32, i32), None),
        (((i32, i32), i32), None),
        ((i32, i32, i32), (None, None, None)),
        ((i32, i32, i32), ([17], [17], [20, 30])),
    ]
    for types, shape_lists in accepted:
      # pylint: disable=protected-access
      restructured = dataset_ops._RestructuredDataset(
          dataset, types, shape_lists)
      # pylint: enable=protected-access
      self.assertEqual(types, restructured.output_types)
      if shape_lists is not None:
        expected_flat = nest.flatten(shape_lists)
        actual_flat = nest.flatten(restructured.output_shapes)
        for expected, actual in zip(expected_flat, actual_flat):
          if expected is not None:
            self.assertEqual(expected, actual.as_list())
          else:
            self.assertIs(None, actual.ndims)

    # dtype mismatch, wrong leaf count, mismatched shape nesting, wrong shape
    # count, and an incompatible static shape.
    rejected = [
        ((i32, dtypes.int64, i32), None),
        ((i32, i32, i32, i32), None),
        ((i32, i32, i32), ((None, None), None)),
        ((i32, i32, i32), (None, None, None, None)),
        ((i32, i32, i32), (None, [None], [21, 30])),
    ]
    for types, shape_lists in rejected:
      with self.assertRaises(ValueError):
        # pylint: disable=protected-access
        dataset_ops._RestructuredDataset(dataset, types, shape_lists)
        # pylint: enable=protected-access
Exemplo n.º 12
0
  def testFlattenAndPack_withDicts(self):
    """Round-trips a mix of tuples, namedtuples, dicts and OrderedDicts."""
    named_tuple = collections.namedtuple("A", ("b", "c"))

    # Dict entries are emitted by sorted key: "b" (5) before "c", and inside
    # the OrderedDict "a" (2) before "b" (3).
    mess = (
        "z",
        named_tuple(3, 4),
        {"c": (1, collections.OrderedDict([("b", 3), ("a", 2)])), "b": 5},
        17,
    )
    flattened = nest.flatten(mess)
    expected_flat = ["z", 3, 4, 5, 1, 2, 3, 17]
    self.assertEqual(flattened, expected_flat)

    # Same shape as `mess` but different leaves; only its structure is used.
    structure_of_mess = (
        14,
        named_tuple("a", True),
        {"c": (0, collections.OrderedDict([("b", 9), ("a", 8)])), "b": 3},
        "hi everybody",
    )
    unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
    self.assertEqual(unflattened, mess)

    # The rebuilt OrderedDict must keep the template's insertion order.
    rebuilt = unflattened[2]["c"][1]
    self.assertIsInstance(rebuilt, collections.OrderedDict)
    self.assertEqual(list(rebuilt), ["b", "a"])
Exemplo n.º 13
0
    def testMapStructure(self):
        """Checks nest.map_structure on matching and mismatched structures."""
        structure1 = (((1, 2), 3), 4, (5, 6))
        structure2 = (((7, 8), 9), 10, (11, 12))
        structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
        nest.assert_same_structure(structure1, structure1_plus1)
        self.assertAllEqual([2, 3, 4, 5, 6, 7], nest.flatten(structure1_plus1))
        structure1_plus_structure2 = nest.map_structure(
            lambda x, y: x + y, structure1, structure2)
        self.assertEqual((((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
                         structure1_plus_structure2)

        # Non-sequence inputs are mapped as single leaves.
        self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))

        self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))

        # `assertRaisesRegex` replaces the `assertRaisesRegexp` alias, which
        # was deprecated in Python 3.2 and removed in Python 3.12.
        with self.assertRaisesRegex(TypeError, "callable"):
            nest.map_structure("bad", structure1_plus1)

        with self.assertRaisesRegex(ValueError, "same nested structure"):
            nest.map_structure(lambda x, y: None, 3, (3, ))

        # A tuple and a dict are different sequence types.
        with self.assertRaisesRegex(TypeError, "same sequence type"):
            nest.map_structure(lambda x, y: None, ((3, 4), 5), {
                "a": (3, 4),
                "b": 5
            })

        with self.assertRaisesRegex(ValueError, "same nested structure"):
            nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))

        # Even with check_types=False, differing nesting still raises.
        with self.assertRaisesRegex(ValueError, "same nested structure"):
            nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
                               check_types=False)

        # Unknown keyword arguments are rejected.
        with self.assertRaisesRegex(ValueError,
                                    "Only valid keyword argument"):
            nest.map_structure(lambda x: None, structure1, foo="a")

        with self.assertRaisesRegex(ValueError,
                                    "Only valid keyword argument"):
            nest.map_structure(lambda x: None,
                               structure1,
                               check_types=False,
                               foo="a")
Exemplo n.º 14
0
    def tf_map_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # NOTE(review): this is a closure; `input_dataset`, `map_func` and
      # `self` come from the enclosing scope, which is not shown here.
      # Pass in shape information from the input_dataset.
      for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
        arg.set_shape(shape)

      # Rebuild the nested element structure from the flat argument list.
      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)

      # Sequence-structured elements are unpacked into positional arguments;
      # a non-sequence element is passed as a single argument.
      if nest.is_sequence(nested_args):
        dataset = map_func(*nested_args)
      else:
        dataset = map_func(nested_args)

      if not isinstance(dataset, dataset_ops.Dataset):
        raise TypeError("`map_func` must return a `Dataset` object.")

      # Side effect: record the returned dataset's output structure on the
      # enclosing object.
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes

      return dataset.make_dataset_resource()
Exemplo n.º 15
0
  def testMapStructure(self):
    """Checks nest.map_structure on matching and mismatched structures."""
    structure1 = (((1, 2), 3), 4, (5, 6))
    structure2 = (((7, 8), 9), 10, (11, 12))
    structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
    nest.assert_same_structure(structure1, structure1_plus1)
    self.assertAllEqual(
        [2, 3, 4, 5, 6, 7],
        nest.flatten(structure1_plus1))
    structure1_plus_structure2 = nest.map_structure(
        lambda x, y: x + y, structure1, structure2)
    self.assertEqual(
        (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
        structure1_plus_structure2)

    # Non-sequence inputs are mapped as single leaves.
    self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))

    self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))

    # `assertRaisesRegex` replaces the `assertRaisesRegexp` alias, which was
    # deprecated in Python 3.2 and removed in Python 3.12.
    with self.assertRaisesRegex(TypeError, "callable"):
      nest.map_structure("bad", structure1_plus1)

    with self.assertRaisesRegex(ValueError, "same nested structure"):
      nest.map_structure(lambda x, y: None, 3, (3,))

    # A tuple and a dict are different sequence types.
    with self.assertRaisesRegex(TypeError, "same sequence type"):
      nest.map_structure(lambda x, y: None, ((3, 4), 5), {"a": (3, 4), "b": 5})

    with self.assertRaisesRegex(ValueError, "same nested structure"):
      nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))

    # Even with check_types=False, differing nesting still raises.
    with self.assertRaisesRegex(ValueError, "same nested structure"):
      nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
                         check_types=False)

    # Unknown keyword arguments are rejected.
    with self.assertRaisesRegex(ValueError, "Only valid keyword argument"):
      nest.map_structure(lambda x: None, structure1, foo="a")

    with self.assertRaisesRegex(ValueError, "Only valid keyword argument"):
      nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
    def __init__(self,
                 cell,
                 attention_mechanism,
                 is_manual_attention,
                 manual_alignments,
                 attention_layer_size=None,
                 alignment_history=False,
                 cell_input_fn=None,
                 output_attention=True,
                 initial_cell_state=None,
                 name=None):
        """Construct the `AttentionWrapper`.

        Args:
            cell: An instance of `RNNCell`; stored as `self._cell`.
            attention_mechanism: A list/tuple of `AttentionMechanism` instances
                or a single instance. A single instance is wrapped in a
                1-tuple, so `self._attention_mechanisms` is always a sequence.
            is_manual_attention: Stored as-is on `self.is_manual_attention`;
                not otherwise used by this constructor. NOTE(review):
                presumably a flag that makes the wrapper use
                `manual_alignments` instead of computed alignments. Confirm
                against `call()`.
            manual_alignments: Stored as-is on `self.manual_alignments`; not
                otherwise used by this constructor. NOTE(review): presumably
                externally supplied alignments; shape/dtype are not
                established here. Confirm against `call()`.
            attention_layer_size: A list/tuple of Python integers or a single
                Python integer, the depth of the attention (output) layer(s).
                If None (default), no attention layers are created and the
                attention size is the sum of the last dimension of each
                mechanism's `values`; the context is used as attention at each
                time step. Otherwise one bias-free `Dense` layer is created per
                mechanism, and the context and cell output are fed into it to
                generate attention at each time step. It must contain exactly
                one entry per attention mechanism (a scalar counts as one).
            alignment_history: Python boolean, whether to store alignment
                history from all time steps in the final output state
                (currently stored as a time major `TensorArray` on which you
                must call `stack()`).
            cell_input_fn: (optional) A `callable`. The default is:
                `lambda inputs, attention: tf.concat([inputs, attention], -1)`.
            output_attention: Python bool. If `True` (default), the output at
                each time step is the attention value. This is the behavior of
                Luong-style attention mechanisms. If `False`, the output at
                each time step is the output of `cell`. This is the behavior
                of Bahdanau-style attention mechanisms. In both cases, the
                `attention` tensor is propagated to the next time step via the
                state and is used there. This flag only controls whether the
                attention mechanism is propagated up to the next cell in an
                RNN stack or to the top RNN output.
            initial_cell_state: The initial state value to use for the cell
                when the user calls `zero_state()`. If provided, every tensor
                in it is passed through `tf.identity` under control
                dependencies from `self._batch_size_checks`, using the batch
                size of the last flattened tensor. Note that if this value is
                provided and the user calls `zero_state` with a `batch_size`
                that does not match the batch size of `initial_cell_state`,
                proper behavior is not guaranteed.
            name: Name passed to the base-class constructor. It is also used as
                the name scope (default "AttentionWrapperInit") for the
                initial-state checks.

        Raises:
            TypeError: if `attention_mechanism` is neither an
                `AttentionMechanism` nor a list/tuple containing only
                `AttentionMechanism` instances, or if `cell_input_fn` is given
                but is not callable.
            ValueError: if `attention_layer_size` is not None and the number of
                sizes it contains differs from the number of attention
                mechanisms.
        """
        super(AttentionWrapper, self).__init__(name=name)

        # Stored for use outside this constructor; not read again here.
        self.is_manual_attention = is_manual_attention
        self.manual_alignments = manual_alignments

        # Normalize into the sequence `attention_mechanisms`, type-checking
        # every element.
        if isinstance(attention_mechanism, (list, tuple)):
            self._is_multi = True
            attention_mechanisms = attention_mechanism
            # NOTE: the loop variable rebinds the `attention_mechanism`
            # parameter; the parameter is not read again after this branch.
            for attention_mechanism in attention_mechanisms:
                if not isinstance(attention_mechanism, AttentionMechanism):
                    raise TypeError(
                        "attention_mechanism must contain only instances of "
                        "AttentionMechanism, saw type: %s" %
                        type(attention_mechanism).__name__)
        else:
            self._is_multi = False
            if not isinstance(attention_mechanism, AttentionMechanism):
                raise TypeError(
                    "attention_mechanism must be an AttentionMechanism or list of "
                    "multiple AttentionMechanism instances, saw type: %s" %
                    type(attention_mechanism).__name__)
            # Wrap the single mechanism so later code can iterate uniformly.
            attention_mechanisms = (attention_mechanism, )

        # Default input function: concatenate inputs and attention along the
        # last axis.
        if cell_input_fn is None:
            cell_input_fn = (
                lambda inputs, attention: tf.concat([inputs, attention], -1))
        else:
            if not callable(cell_input_fn):
                raise TypeError(
                    "cell_input_fn must be callable, saw type: %s" %
                    type(cell_input_fn).__name__)

        if attention_layer_size is not None:
            # Promote a scalar size to a 1-tuple so it lines up with the
            # mechanisms one-to-one.
            attention_layer_sizes = tuple(attention_layer_size if isinstance(
                attention_layer_size, (list,
                                       tuple)) else (attention_layer_size, ))
            if len(attention_layer_sizes) != len(attention_mechanisms):
                raise ValueError(
                    "If provided, attention_layer_size must contain exactly one "
                    "integer per attention_mechanism, saw: %d vs %d" %
                    (len(attention_layer_sizes), len(attention_mechanisms)))
            # One bias-free Dense projection per mechanism; the total attention
            # size is the sum of their depths.
            self._attention_layers = tuple(
                layers_core.Dense(attention_layer_size,
                                  name="attention_layer",
                                  use_bias=False)
                for attention_layer_size in attention_layer_sizes)
            self._attention_layer_size = sum(attention_layer_sizes)
        else:
            # No projection: attention size is the summed static last
            # dimension of every mechanism's `values`. NOTE(review): if any of
            # those dimensions is statically unknown (`.value` is None), `sum`
            # raises TypeError.
            self._attention_layers = None
            self._attention_layer_size = sum(
                attention_mechanism.values.get_shape()[-1].value
                for attention_mechanism in attention_mechanisms)

        self._cell = cell
        self._attention_mechanisms = attention_mechanisms
        self._cell_input_fn = cell_input_fn
        self._output_attention = output_attention
        self._alignment_history = alignment_history
        with tf.name_scope(name, "AttentionWrapperInit"):
            if initial_cell_state is None:
                self._initial_cell_state = None
            else:
                # Batch size comes from the last tensor of the flattened state:
                # the static dim 0 when known, otherwise the dynamic
                # tf.shape value. NOTE(review): a static size of 0 is falsy
                # and would also fall back to the dynamic value.
                final_state_tensor = nest.flatten(initial_cell_state)[-1]
                state_batch_size = (final_state_tensor.shape[0].value
                                    or tf.shape(final_state_tensor)[0])
                error_message = (
                    "When constructing AttentionWrapper %s: " % self._base_name
                    + "Non-matching batch sizes between the memory "
                    "(encoder output) and initial_cell_state.    Are you using "
                    "the BeamSearchDecoder?    You may need to tile your initial state "
                    "via the tf.contrib.seq2seq.tile_batch function with argument "
                    "multiple=beam_width.")
                # Route every state tensor through tf.identity so that the ops
                # returned by `_batch_size_checks` (presumably assertions) run
                # before the initial state is used.
                with tf.control_dependencies(
                        self._batch_size_checks(state_batch_size,
                                                error_message)):
                    self._initial_cell_state = nest.map_structure(
                        lambda s: tf.identity(s,
                                              name="check_initial_cell_state"),
                        initial_cell_state)
Exemplo n.º 17
0
    def testFlattenUpTo(self):
        """Checks nest.flatten_up_to against shallow/input tree combinations.

        Covers nested tuples, a non-sequence shallow tree, both trees being
        non-sequences, a non-sequence input under a sequence shallow tree
        (which must raise TypeError), and dicts.
        """
        input_tree = (((2, 2), (3, 3)), ((4, 9), (5, 5)))
        shallow_tree = ((True, True), (False, True))
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [(2, 2), (3, 3), (4, 9),
                                                (5, 5)])
        self.assertEqual(flattened_shallow_tree, [True, True, False, True])

        # Flattening stops at the shallow tree's leaves, so the
        # (letter, number) pairs stay intact, unlike with a full `flatten`.
        input_tree = ((("a", 1), (("b", 2), (("c", 3), (("d", 4))))))
        shallow_tree = (("level_1", ("level_2", ("level_3", ("level_4")))))
        input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
            shallow_tree, input_tree)
        input_tree_flattened = nest.flatten(input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree, [("a", 1),
                                                                ("b", 2),
                                                                ("c", 3),
                                                                ("d", 4)])
        self.assertEqual(input_tree_flattened,
                         ["a", 1, "b", 2, "c", 3, "d", 4])

        ## Shallow non-list edge-case.
        # Using iterable elements.
        input_tree = ["input_tree"]
        shallow_tree = "shallow_tree"
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        input_tree = ("input_tree_0", "input_tree_1")
        shallow_tree = "shallow_tree"
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        # Using non-iterable elements.
        input_tree = (0, )
        shallow_tree = 9
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        input_tree = (0, 1)
        shallow_tree = 9
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        ## Both non-list edge-case.
        # Using iterable elements.
        input_tree = "input_tree"
        shallow_tree = "shallow_tree"
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        # Using non-iterable elements.
        input_tree = 0
        shallow_tree = 0
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        ## Input non-list edge-case.
        # `assertRaisesRegex` replaces the `assertRaisesRegexp` alias, which
        # was removed in Python 3.12. The "(type|class)" alternation matches
        # both the Python 2 and Python 3 reprs of a type.
        # Using iterable elements.
        input_tree = "input_tree"
        shallow_tree = ("shallow_tree", )
        expected_message = (
            "If shallow structure is a sequence, input must also "
            "be a sequence. Input has type: <(type|class) 'str'>.")
        with self.assertRaisesRegex(TypeError, expected_message):
            nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, list(shallow_tree))

        input_tree = "input_tree"
        shallow_tree = ("shallow_tree_9", "shallow_tree_8")
        with self.assertRaisesRegex(TypeError, expected_message):
            nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, list(shallow_tree))

        # Using non-iterable elements.
        input_tree = 0
        shallow_tree = (9, )
        expected_message = (
            "If shallow structure is a sequence, input must also "
            "be a sequence. Input has type: <(type|class) 'int'>.")
        with self.assertRaisesRegex(TypeError, expected_message):
            nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, list(shallow_tree))

        input_tree = 0
        shallow_tree = (9, 8)
        with self.assertRaisesRegex(TypeError, expected_message):
            nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, list(shallow_tree))

        # Using dict.
        input_tree = {"a": ((2, 2), (3, 3)), "b": ((4, 9), (5, 5))}
        shallow_tree = {"a": (True, True), "b": (False, True)}
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [(2, 2), (3, 3), (4, 9),
                                                (5, 5)])
        self.assertEqual(flattened_shallow_tree, [True, True, False, True])
Exemplo n.º 18
0
            self._attention_layers = None
            self._attention_layer_size = sum(
                attention_mechanism.values.get_shape()[-1].value
                for attention_mechanism in attention_mechanisms)

        self._cell = cell
        self._attention_mechanisms = attention_mechanisms
        self._cell_input_fn = cell_input_fn
        self._output_attention = output_attention
        self._alignment_history = alignment_history

        with tf.name_scope(name, "AttentionWrapperInit"):
            if initial_cell_state is None:
                self._initial_cell_state = None
            else:
                final_state_tensor = nest.flatten(initial_cell_state)[-1]
                state_batch_size = (
                    final_state_tensor.shape[0].value
                    or tf.shape(final_state_tensor)[0])
                error_message = (
                    "When contructing AttentionWrapper %s: " % self._base_name +
                    "Non-matching batch sizes between the memory "
                    "(encoder output) and initial_cell_state.       Are you using "
                    "the BeamSearchDecoder?     You may need to tile your initial state "
                    "via the tf.contrib.seq2seq.tile_batch function with argument "
                    "multiple=beam_width.")
                with tf.control_dependencies(
                    self._batch_size_checks(state_batch_size, error_message)):
                self._initial_cell_state = nest.map_structure(
                    lambda s: tf.identity(s, name="check_initial_cell_state"),
                    initial_cell_state)
  def testNestedStructure(self):
    """Checks that nested dtypes and shapes propagate through `Dataset` ops.

    Builds a dataset from a nested tuple of NumPy arrays, then verifies
    `output_types` / `output_shapes` after `shuffle`, `repeat`, `filter`,
    `take`, `map`, `flat_map` and `batch`; checks the dtypes and static shapes
    of the tensors produced by both iterator kinds; and finally checks the
    per-element structure produced by `from_tensor_slices`.
    """

    def assert_structure(ds, expected_types, expected_shapes):
      # Uses `assertEqual`: the `assertEquals` alias is deprecated and was
      # removed from `unittest` in Python 3.12.
      self.assertEqual(expected_types, ds.output_types)
      self.assertEqual(expected_shapes, ds.output_shapes)

    triple_types = (dtypes.int64, (dtypes.float64, dtypes.float64),
                    dtypes.int64)
    triple_shapes = ([3], ([2], [2]), [3])
    pair_types = ((dtypes.int64, dtypes.int64),
                  (dtypes.float64, dtypes.float64))
    pair_shapes = (([3], [3]), ([2], [2]))

    components = (np.array([1, 2, 3]), (np.array([4., 5.]), np.array([6., 7.])),
                  np.array([8, 9, 10]))

    dataset = dataset_ops.Dataset.from_tensors(components)
    assert_structure(dataset, triple_types, triple_shapes)

    # These transformations are expected to leave the element structure as is.
    dataset = dataset.shuffle(10, 10)
    assert_structure(dataset, triple_types, triple_shapes)
    dataset = dataset.repeat(-1)
    assert_structure(dataset, triple_types, triple_shapes)
    dataset = dataset.filter(lambda x, y, z: True)
    assert_structure(dataset, triple_types, triple_shapes)
    dataset = dataset.take(5)
    assert_structure(dataset, triple_types, triple_shapes)

    # Regroup the triple into two pairs; `flat_map` then re-emits that shape.
    dataset = dataset.map(lambda x, y, z: ((x, z), (y[0], y[1])))
    assert_structure(dataset, pair_types, pair_shapes)

    dataset = dataset.flat_map(
        lambda x, y: dataset_ops.Dataset.from_tensors(((x[0], x[1]),
                                                       (y[0], y[1])))
    )
    assert_structure(dataset, pair_types, pair_shapes)

    # Batching adds an unknown (None) leading dimension; each shape is turned
    # into a plain list before comparing.
    dataset = dataset.batch(32)
    self.assertEqual(pair_types, dataset.output_types)
    self.assertEqual((([None, 3], [None, 3]), ([None, 2], [None, 2])),
                     nest.pack_sequence_as(dataset.output_shapes, [
                         s.as_list()
                         for s in nest.flatten(dataset.output_shapes)
                     ]))

    # Both iterator kinds must expose the same dtypes and static shapes.
    for make_iterator in (dataset.make_one_shot_iterator,
                          dataset.make_initializable_iterator):
      (w, x), (y, z) = make_iterator().get_next()
      self.assertEqual(dtypes.int64, w.dtype)
      self.assertEqual(dtypes.int64, x.dtype)
      self.assertEqual(dtypes.float64, y.dtype)
      self.assertEqual(dtypes.float64, z.dtype)
      self.assertEqual([None, 3], w.shape.as_list())
      self.assertEqual([None, 3], x.shape.as_list())
      self.assertEqual([None, 2], y.shape.as_list())
      self.assertEqual([None, 2], z.shape.as_list())

    # Define a separate set of components with matching leading
    # dimension for the from-slices constructor.
    components_for_slices = (np.array([1, 2, 3]),
                             (np.array([4., 5., 6.]), np.array([7., 8., 9.])),
                             np.array([10, 11, 12]))

    dataset = dataset_ops.Dataset.from_tensor_slices(components_for_slices)
    assert_structure(dataset, triple_types, ([], ([], []), []))
Exemplo n.º 20
0
  def testFlattenUpTo(self):
    """Exercises `nest.flatten_up_to` on matching, scalar and invalid trees.

    Each case flattens an input tree up to a shallow tree and also flattens
    the shallow tree against itself. The invalid cases expect a `TypeError`
    when the shallow structure is a sequence but the input is not.
    """

    def check(shallow, tree, want_tree, want_shallow):
      got_tree = nest.flatten_up_to(shallow, tree)
      got_shallow = nest.flatten_up_to(shallow, shallow)
      self.assertEqual(got_tree, want_tree)
      self.assertEqual(got_shallow, want_shallow)

    # Two-level tuples: flattening stops at the shallow tree's leaves.
    check(((True, True), (False, True)),
          (((2, 2), (3, 3)), ((4, 9), (5, 5))),
          [(2, 2), (3, 3), (4, 9), (5, 5)],
          [True, True, False, True])

    # Deeply nested pairs: `flatten_up_to` keeps each (letter, number) pair
    # intact, whereas a full `flatten` splits it into scalars.
    deep_tree = (("a", 1), (("b", 2), (("c", 3), ("d", 4))))
    deep_shallow = ("level_1", ("level_2", ("level_3", "level_4")))
    up_to_shallow = nest.flatten_up_to(deep_shallow, deep_tree)
    fully_flat = nest.flatten(deep_tree)
    self.assertEqual(up_to_shallow, [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
    self.assertEqual(fully_flat, ["a", 1, "b", 2, "c", 3, "d", 4])

    # Non-sequence shallow tree (first four rows), then both trees
    # non-sequences (last two rows): the whole input comes back as one leaf.
    scalar_shallow_cases = [
        (["input_tree"], "shallow_tree"),
        (("input_tree_0", "input_tree_1"), "shallow_tree"),
        ((0,), 9),
        ((0, 1), 9),
        ("input_tree", "shallow_tree"),
        (0, 0),
    ]
    for tree, shallow in scalar_shallow_cases:
      check(shallow, tree, [tree], [shallow])

    # Sequence shallow tree with a non-sequence input must be rejected, while
    # the shallow tree still flattens against itself into a plain list.
    str_message = ("If shallow structure is a sequence, input must also "
                   "be a sequence. Input has type: <(type|class) 'str'>.")
    int_message = ("If shallow structure is a sequence, input must also "
                   "be a sequence. Input has type: <(type|class) 'int'>.")
    mismatched_cases = [
        ("input_tree", ("shallow_tree",), str_message),
        ("input_tree", ("shallow_tree_9", "shallow_tree_8"), str_message),
        (0, (9,), int_message),
        (0, (9, 8), int_message),
    ]
    for tree, shallow, message in mismatched_cases:
      with self.assertRaisesRegexp(TypeError, message):
        nest.flatten_up_to(shallow, tree)
      self.assertEqual(nest.flatten_up_to(shallow, shallow), list(shallow))

    # Dicts are traversed like the equivalent tuples of their values.
    check({"a": (True, True), "b": (False, True)},
          {"a": ((2, 2), (3, 3)), "b": ((4, 9), (5, 5))},
          [(2, 2), (3, 3), (4, 9), (5, 5)],
          [True, True, False, True])
    def testNestedStructure(self):
        """Checks that nested dtypes and shapes propagate through `Dataset` ops.

        Builds a dataset from a nested tuple of NumPy arrays, then verifies
        `output_types` / `output_shapes` after `shuffle`, `repeat`, `filter`,
        `take`, `map`, `flat_map` and `batch`; checks the dtypes and static
        shapes of the tensors produced by both iterator kinds; and finally
        checks the per-element structure produced by `from_tensor_slices`.
        """

        def assert_structure(ds, expected_types, expected_shapes):
            # Uses `assertEqual`: the `assertEquals` alias is deprecated and
            # was removed from `unittest` in Python 3.12.
            self.assertEqual(expected_types, ds.output_types)
            self.assertEqual(expected_shapes, ds.output_shapes)

        triple_types = (dtypes.int64, (dtypes.float64, dtypes.float64),
                        dtypes.int64)
        triple_shapes = ([3], ([2], [2]), [3])
        pair_types = ((dtypes.int64, dtypes.int64),
                      (dtypes.float64, dtypes.float64))
        pair_shapes = (([3], [3]), ([2], [2]))

        components = (np.array([1, 2,
                                3]), (np.array([4., 5.]), np.array([6., 7.])),
                      np.array([8, 9, 10]))

        dataset = dataset_ops.Dataset.from_tensors(components)
        assert_structure(dataset, triple_types, triple_shapes)

        # These transformations are expected to leave the element structure
        # as is.
        dataset = dataset.shuffle(10, 10)
        assert_structure(dataset, triple_types, triple_shapes)
        dataset = dataset.repeat(-1)
        assert_structure(dataset, triple_types, triple_shapes)
        dataset = dataset.filter(lambda x, y, z: True)
        assert_structure(dataset, triple_types, triple_shapes)
        dataset = dataset.take(5)
        assert_structure(dataset, triple_types, triple_shapes)

        # Regroup the triple into two pairs; `flat_map` then re-emits that
        # shape.
        dataset = dataset.map(lambda x, y, z: ((x, z), (y[0], y[1])))
        assert_structure(dataset, pair_types, pair_shapes)

        dataset = dataset.flat_map(lambda x, y: dataset_ops.Dataset.
                                   from_tensors(((x[0], x[1]), (y[0], y[1]))))
        assert_structure(dataset, pair_types, pair_shapes)

        # Batching adds an unknown (None) leading dimension; each shape is
        # turned into a plain list before comparing.
        dataset = dataset.batch(32)
        self.assertEqual(pair_types, dataset.output_types)
        self.assertEqual(
            (([None, 3], [None, 3]), ([None, 2], [None, 2])),
            nest.pack_sequence_as(
                dataset.output_shapes,
                [s.as_list() for s in nest.flatten(dataset.output_shapes)]))

        # Both iterator kinds must expose the same dtypes and static shapes.
        for make_iterator in (dataset.make_one_shot_iterator,
                              dataset.make_initializable_iterator):
            (w, x), (y, z) = make_iterator().get_next()
            self.assertEqual(dtypes.int64, w.dtype)
            self.assertEqual(dtypes.int64, x.dtype)
            self.assertEqual(dtypes.float64, y.dtype)
            self.assertEqual(dtypes.float64, z.dtype)
            self.assertEqual([None, 3], w.shape.as_list())
            self.assertEqual([None, 3], x.shape.as_list())
            self.assertEqual([None, 2], y.shape.as_list())
            self.assertEqual([None, 2], z.shape.as_list())

        # Define a separate set of components with matching leading
        # dimension for the from-slices constructor.
        components_for_slices = (np.array([1, 2, 3]), (np.array([4., 5., 6.]),
                                                       np.array([7., 8., 9.])),
                                 np.array([10, 11, 12]))

        dataset = dataset_ops.Dataset.from_tensor_slices(components_for_slices)
        assert_structure(dataset, triple_types, ([], ([], []), []))