Example #1
    def testGetTraverseShallowStructure(self):
        scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7, )}, []]
        scalar_traverse_r = nest.get_traverse_shallow_structure(
            lambda s: not isinstance(s, tuple), scalar_traverse_input)
        self.assertEqual(scalar_traverse_r,
                         [True, True, False, [True, True], {"a": False}, []])
        nest.assert_shallow_structure(scalar_traverse_r, scalar_traverse_input)

        structure_traverse_input = [(1, [2]), ([1], 2)]
        structure_traverse_r = nest.get_traverse_shallow_structure(
            lambda s: (True, False) if isinstance(s, tuple) else True,
            structure_traverse_input)
        self.assertEqual(structure_traverse_r, [(True, False),
                                                ([True], False)])
        nest.assert_shallow_structure(structure_traverse_r,
                                      structure_traverse_input)

        with self.assertRaisesRegexp(TypeError, "returned structure"):
            nest.get_traverse_shallow_structure(lambda _: [True], 0)

        with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"):
            nest.get_traverse_shallow_structure(lambda _: 1, [1])

        with self.assertRaisesRegexp(
                TypeError, "didn't return a depth=1 structure of bools"):
            nest.get_traverse_shallow_structure(lambda _: [1], [1])
Example #2
  def testGetTraverseShallowStructure(self):
    scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
    scalar_traverse_r = nest.get_traverse_shallow_structure(
        lambda s: not isinstance(s, tuple),
        scalar_traverse_input)
    self.assertEqual(scalar_traverse_r,
                     [True, True, False, [True, True], {"a": False}, []])
    nest.assert_shallow_structure(scalar_traverse_r,
                                  scalar_traverse_input)

    structure_traverse_input = [(1, [2]), ([1], 2)]
    structure_traverse_r = nest.get_traverse_shallow_structure(
        lambda s: (True, False) if isinstance(s, tuple) else True,
        structure_traverse_input)
    self.assertEqual(structure_traverse_r,
                     [(True, False), ([True], False)])
    nest.assert_shallow_structure(structure_traverse_r,
                                  structure_traverse_input)

    with self.assertRaisesRegexp(TypeError, "returned structure"):
      nest.get_traverse_shallow_structure(lambda _: [True], 0)

    with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"):
      nest.get_traverse_shallow_structure(lambda _: 1, [1])

    with self.assertRaisesRegexp(
        TypeError, "didn't return a depth=1 structure of bools"):
      nest.get_traverse_shallow_structure(lambda _: [1], [1])
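For quick reference, a minimal standalone sketch of the same call outside the test harness (assuming nest is tensorflow.python.util.nest, as in the tests above); the expected result mirrors the first assertion:

from tensorflow.python.util import nest

# traverse_fn returning True keeps descending into a substructure; returning
# False marks that position as atomic in the output.
result = nest.get_traverse_shallow_structure(
    lambda s: not isinstance(s, tuple),
    [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []])
# result == [True, True, False, [True, True], {"a": False}, []]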
Example #3
    def __call__(self, inputs, state, scope=None):
        """Run the cell with the declared dropouts."""
        def _should_dropout(p):
            return (not isinstance(p, float)) or p < 1

        if _should_dropout(self._input_keep_prob):
            inputs = tf.cond(
                self.is_train,
                lambda: self._dropout(
                    inputs, "input", self._recurrent_input_noise,
                    self._input_keep_prob),
                lambda: inputs * self._input_keep_prob)

        output, new_state = self._cell(inputs, state, scope)
        if _should_dropout(self._state_keep_prob):
            # Identify which subsets of the state to perform dropout on and
            # which ones to keep.

            shallow_filtered_substructure = nest.get_traverse_shallow_structure(
                self._dropout_state_filter, new_state)

            new_state = tf.cond(
                self.is_train,
                lambda: self._dropout(
                    new_state, "state", self._recurrent_state_noise,
                    self._state_keep_prob, shallow_filtered_substructure),
                lambda: new_state * self._state_keep_prob)

        if _should_dropout(self._output_keep_prob):
            output = tf.cond(
                self.is_train,
                lambda: self._dropout(
                    output, "output", self._recurrent_output_noise,
                    self._output_keep_prob),
                lambda: output * self._output_keep_prob)

        return output, new_state
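The tf.cond pattern above switches between applying dropout at training time and rescaling by the keep probability at inference time. A minimal self-contained sketch of that switch, using illustrative names (is_train, keep_prob) rather than the wrapper's attributes:

import tensorflow as tf

is_train = tf.constant(True)  # stands in for the wrapper's training flag
x = tf.ones([2, 3])
keep_prob = 0.8

# Training branch: random dropout; inference branch: deterministic rescaling.
y = tf.cond(is_train,
            lambda: tf.nn.dropout(x, rate=1.0 - keep_prob),
            lambda: x * keep_prob)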
Example #4
def _get_event_shape_shallow_structure(event_shape):
    """Gets shallow structure, treating lists of ints at the leaves as atomic."""
    def _not_list_of_ints(s):
        if isinstance(s, list) or isinstance(s, tuple):
            return not all(isinstance(x, int) for x in s)
        return True

    return nest.get_traverse_shallow_structure(_not_list_of_ints, event_shape)
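A hypothetical call to illustrate the helper above: a list of ints at a leaf is treated as one atomic event shape, while other containers and scalar leaves are traversed:

event_shape = {'x': [2, 3], 'y': 5}   # hypothetical input
shallow = _get_event_shape_shallow_structure(event_shape)
# Expected: {'x': False, 'y': True} -- the int list is marked atomic (False),
# the scalar leaf stays traversable (True).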
Example #5
  def testNestGetTraverseShallowStructure(self):
    func = lambda t: not (isinstance(t, CT) and t.metadata == 'B')
    structure = [CT([1, 2], metadata='A'), CT([CT(3)], metadata='B')]

    result = nest.get_traverse_shallow_structure(
        func, structure, expand_composites=True)
    expected = [CT([True, True], metadata='A'), False]
    self.assertEqual(result, expected)
Example #6
  def testNestGetTraverseShallowStructure(self):
    func = lambda t: not (isinstance(t, CT) and t.metadata == 'B')
    structure = [CT([1, 2], metadata='A'), CT([CT(3)], metadata='B')]

    result = nest.get_traverse_shallow_structure(
        func, structure, expand_composites=True)
    expected = [CT([True, True], metadata='A'), False]
    self.assertEqual(result, expected)
Example #7
def independent_joint_distribution_from_structure(structure_of_distributions,
                                                  validate_args=False):
    """Turns a (potentially nested) structure of dists into a single dist.

  Args:
    structure_of_distributions: instance of `tfd.Distribution`, or nested
      structure (tuple, list, dict, etc.) in which all leaves are
      `tfd.Distribution` instances.
    validate_args: Python `bool`. Whether the joint distribution should validate
      input with asserts. This imposes a runtime cost. If `validate_args` is
      `False`, and the inputs are invalid, correct behavior is not guaranteed.
      Default value: `False`.
  Returns:
    distribution: instance of `tfd.Distribution` such that
      `distribution.sample()` is equivalent to
      `tf.nest.map_structure(lambda d: d.sample(), structure_of_distributions)`.
      If `structure_of_distributions` was indeed a structure (as opposed to
      a single `Distribution` instance), this will be a `JointDistribution`
      with the corresponding structure.
  Raises:
    TypeError: if any leaves of the input structure are not `tfd.Distribution`
      instances.
  """
    # If input is already a Distribution, just return it.
    if dist_util.is_distribution_instance(structure_of_distributions):
        return structure_of_distributions

    # If this structure contains other structures (i.e., has elements at depth > 1),
    # recursively turn them into JDs.
    element_depths = nest.map_structure_with_tuple_paths(
        lambda path, x: len(path), structure_of_distributions)
    if max(tf.nest.flatten(element_depths)) > 1:
        next_level_shallow_structure = nest.get_traverse_shallow_structure(
            traverse_fn=lambda x: min(tf.nest.flatten(x)) <= 1,
            structure=element_depths)
        structure_of_distributions = nest.map_structure_up_to(
            next_level_shallow_structure,
            independent_joint_distribution_from_structure,
            structure_of_distributions)

    # Otherwise, build a JD from the current structure.
    if (hasattr(structure_of_distributions, '_asdict')
            or isinstance(structure_of_distributions, collections.abc.Mapping)):
        return joint_distribution_named.JointDistributionNamed(
            structure_of_distributions, validate_args=validate_args)
    return joint_distribution_sequential.JointDistributionSequential(
        structure_of_distributions, validate_args=validate_args)
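A hedged usage sketch based on the docstring above: a dict whose leaves are tfd.Distribution instances yields a JointDistributionNamed whose sample() has the same structure (assuming tfd is tfp.distributions):

import tensorflow_probability as tfp
tfd = tfp.distributions

dists = {'loc': tfd.Normal(loc=0., scale=1.),
         'scale': tfd.Gamma(concentration=1., rate=1.)}
joint = independent_joint_distribution_from_structure(dists)
sample = joint.sample()  # a dict with keys 'loc' and 'scale'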
Example #8
    def __call__(self, inputs, state, scope=None):
        """Run the cell with the declared dropouts."""

        # store phase value

        if len(inputs.shape) > 1:
            inputs, self.phase = array_ops.split(
                inputs, [inputs.shape[1].value - 1, 1], axis=1)
            reap = 1
            print(inputs.shape)
            print(self.phase.shape)
            # print(phase)
            # phase = inputs[:,-1]
            # inputs = inputs[:,:-1]
        else:
            inputs, self.phase = array_ops.split(inputs, [inputs.shape[0].value - 1, 1], axis=0)
            reap = 0
            # phase = inputs[-1]
            # inputs = inputs[:-1]

        def _should_dropout(p):
            return (not isinstance(p, float)) or p < 1

        if _should_dropout(self._input_keep_prob):
            inputs = self._dropout(inputs, "input",
                                 self._recurrent_input_noise,
                                 self._input_keep_prob)

        # re-append phase so PFNN can use it

        inputs = array_ops.concat([inputs, self.phase], reap)

        output, new_state = self._cell(inputs, state)

        if _should_dropout(self._state_keep_prob):
            # Identify which subsets of the state to perform dropout on and
            # which ones to keep.
            shallow_filtered_substructure = nest.get_traverse_shallow_structure(
                self._dropout_state_filter, new_state)
            new_state = self._dropout(new_state, "state",
                                      self._recurrent_state_noise,
                                      self._state_keep_prob,
                                      shallow_filtered_substructure)
        if _should_dropout(self._output_keep_prob):
            output = self._dropout(output, "output",
                                 self._recurrent_output_noise,
                                 self._output_keep_prob)
        return output, new_state
Example #9
    def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
        """Runs the wrapped cell and applies dropout.

    Args:
      inputs: A tensor with wrapped cell's input.
      state: A tensor or tuple of tensors with wrapped cell's state.
      cell_call_fn: Wrapped cell's method to use for step computation (cell's
        `__call__` or 'call' method).
      **kwargs: Additional arguments.

    Returns:
      A pair containing:

      - Output: A tensor with cell's output.
      - New state: A tensor or tuple of tensors with new wrapped cell's state.
    """
        def _should_dropout(p):
            return (not isinstance(p, float)) or p < 1

        if _should_dropout(self._input_keep_prob):
            inputs = self._dropout(inputs, "input",
                                   self._recurrent_input_noise,
                                   self._input_keep_prob)
        output, new_state = cell_call_fn(inputs, state, **kwargs)
        if _should_dropout(self._state_keep_prob):
            # Identify which subsets of the state to perform dropout on and
            # which ones to keep.
            shallow_filtered_substructure = nest.get_traverse_shallow_structure(
                self._dropout_state_filter, new_state)
            new_state = self._dropout(new_state, "state",
                                      self._recurrent_state_noise,
                                      self._state_keep_prob,
                                      shallow_filtered_substructure)
        if _should_dropout(self._output_keep_prob):
            output = self._dropout(output, "output",
                                   self._recurrent_output_noise,
                                   self._output_keep_prob)
        return output, new_state
Example #10
  def __call__(self, inputs, state, scope=None):
    """Run the cell with the declared dropouts."""
    def _should_dropout(p):
      return (not isinstance(p, float)) or p < 1

    if _should_dropout(self._input_keep_prob):
      inputs = self._dropout(inputs, "input",
                             self._recurrent_input_noise,
                             self._input_keep_prob)
    output, new_state = self._cell(inputs, state, scope=scope)
    if _should_dropout(self._state_keep_prob):
      # Identify which subsets of the state to perform dropout on and
      # which ones to keep.
      shallow_filtered_substructure = nest.get_traverse_shallow_structure(
          self._dropout_state_filter, new_state)
      new_state = self._dropout(new_state, "state",
                                self._recurrent_state_noise,
                                self._state_keep_prob,
                                shallow_filtered_substructure)
    if _should_dropout(self._output_keep_prob):
      output = self._dropout(output, "output",
                             self._recurrent_output_noise,
                             self._output_keep_prob)
    return output, new_state
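The self._dropout_state_filter visitor above decides, per substate, whether dropout should be applied; get_traverse_shallow_structure expands its answers over the full state structure. A minimal illustrative visitor (a sketch, not necessarily the wrapper's built-in default):

import tensorflow as tf

def dropout_state_filter_visitor(substate):
  # Keep the LSTM cell state c intact and apply dropout only to the output h;
  # any other substate is dropped out as usual.
  if isinstance(substate, tf.compat.v1.nn.rnn_cell.LSTMStateTuple):
    return tf.compat.v1.nn.rnn_cell.LSTMStateTuple(c=False, h=True)
  return True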
Example #11
def independent_joint_distribution_from_structure(structure_of_distributions,
                                                  batch_ndims=None,
                                                  validate_args=False):
    """Turns a (potentially nested) structure of dists into a single dist.

  Args:
    structure_of_distributions: instance of `tfd.Distribution`, or nested
      structure (tuple, list, dict, etc.) in which all leaves are
      `tfd.Distribution` instances.
    batch_ndims: Optional integer `Tensor` number of leftmost batch dimensions
      shared across all members of the input structure. If this is specified,
      the returned joint distribution will be an autobatched distribution with
      the given batch rank, and all other dimensions absorbed into the event.
    validate_args: Python `bool`. Whether the joint distribution should validate
      input with asserts. This imposes a runtime cost. If `validate_args` is
      `False`, and the inputs are invalid, correct behavior is not guaranteed.
      Default value: `False`.
  Returns:
    distribution: instance of `tfd.Distribution` such that
      `distribution.sample()` is equivalent to
      `tf.nest.map_structure(lambda d: d.sample(), structure_of_distributions)`.
      If `structure_of_distributions` was indeed a structure (as opposed to
      a single `Distribution` instance), this will be a `JointDistribution`
      with the corresponding structure.
  Raises:
    TypeError: if any leaves of the input structure are not `tfd.Distribution`
      instances.
  """
    # If input is already a Distribution, just return it.
    if dist_util.is_distribution_instance(structure_of_distributions):
        dist = structure_of_distributions
        if batch_ndims is not None:
            excess_ndims = ps.rank_from_shape(
                dist.batch_shape_tensor()) - batch_ndims
            if tf.get_static_value(
                    excess_ndims) != 0:  # Static value may be None.
                dist = independent.Independent(
                    dist, reinterpreted_batch_ndims=excess_ndims)
        return dist

    # If this structure contains other structures (i.e., has elements at depth > 1),
    # recursively turn them into JDs.
    element_depths = nest.map_structure_with_tuple_paths(
        lambda path, x: len(path), structure_of_distributions)
    if max(tf.nest.flatten(element_depths)) > 1:
        next_level_shallow_structure = nest.get_traverse_shallow_structure(
            traverse_fn=lambda x: min(tf.nest.flatten(x)) <= 1,
            structure=element_depths)
        structure_of_distributions = nest.map_structure_up_to(
            next_level_shallow_structure,
            functools.partial(independent_joint_distribution_from_structure,
                              batch_ndims=batch_ndims,
                              validate_args=validate_args),
            structure_of_distributions)

    jdnamed = joint_distribution_named.JointDistributionNamed
    jdsequential = joint_distribution_sequential.JointDistributionSequential
    # Use an autobatched JD if a specific batch rank was requested.
    if batch_ndims is not None:
        jdnamed = functools.partial(
            joint_distribution_auto_batched.JointDistributionNamedAutoBatched,
            batch_ndims=batch_ndims,
            use_vectorized_map=False)
        jdsequential = functools.partial(
            joint_distribution_auto_batched.
            JointDistributionSequentialAutoBatched,
            batch_ndims=batch_ndims,
            use_vectorized_map=False)

    # Otherwise, build a JD from the current structure.
    if (hasattr(structure_of_distributions, '_asdict') or isinstance(
            structure_of_distributions, collections.abc.Mapping)):
        return jdnamed(structure_of_distributions, validate_args=validate_args)
    return jdsequential(structure_of_distributions,
                        validate_args=validate_args)
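A hedged sketch of the batch_ndims path described in the docstring: when a shared batch rank is given, the result is an autobatched joint distribution and any excess dimensions are absorbed into the event (tfd assumed to be tfp.distributions):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

dists = {'w': tfd.Normal(loc=tf.zeros([4, 3]), scale=1.),
         'b': tfd.Normal(loc=tf.zeros([4]), scale=1.)}
# Treat the shared leading dimension of size 4 as the batch dimension.
joint = independent_joint_distribution_from_structure(dists, batch_ndims=1)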
Example #12
def _get_shallow_structure(struct):
    # Get a shallow version of struct where the children are replaced by
    # 'False'.
    return nest.get_traverse_shallow_structure(lambda s: s is struct, struct)
Example #13
def get_shallow_tree(is_leaf, tree):
    """Returns a shallow tree, expanding only when is_leaf(subtree) is False."""
    return nest.get_traverse_shallow_structure(lambda t: not is_leaf(t), tree)
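An illustrative call with hypothetical data: get_shallow_tree stops descending wherever is_leaf is true (marking that position False) and maps the remaining leaves to True:

tree = [{'a': 1}, [2, 3]]
shallow = get_shallow_tree(lambda t: isinstance(t, dict), tree)
# shallow == [False, [True, True]] -- the dict is treated as a leaf, the
# inner list is expanded.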
Example #14
    def __call__(self, inputs, state, scope=None):
        """Run the cell with the declared dropouts."""

        print('SkipDropoutWrapper_cell', self._cell)
        print('SkipDropoutWrapper_inputs', inputs)
        print('SkipDropoutWrapper_state', state)

        if isinstance(state, list):
            state = tuple(state)

        def _should_dropout(p):
            return (not isinstance(p, float)) or p < 1

        if _should_dropout(self._input_keep_prob):
            #TODO is only needed if multiple SkipCells are used
            #rebuild = False
            #if isinstance(inputs, SkipLSTMOutputTuple):
            #    inputs, state_gate = inputs
            #    rebuild = True
            print('SkipDropoutWrapper_inputs_sot', inputs)
            inputs = self._dropout(inputs, "input",
                                   self._recurrent_input_noise,
                                   self._input_keep_prob)
        output, new_state = self._cell(inputs, state, scope=scope)
        #if rebuild:
        #    output, new_state_gate = output
        #    output = SkipLSTMOutputTuple(output, new_state_gate)

        # Separating SkipState and using the LSTMStateTuple as new_state for Dropout
        if isinstance(self._cell, MultiSkipLSTMCell):
            _, _, up, cup = new_state[-1]
            new_state = [LSTMStateTuple(s.c, s.h) for s in state]
        elif isinstance(self._cell, SkipLSTMCell):
            c, h, up, cup = new_state
            new_state = LSTMStateTuple(c, h)
        print('SkipDropoutWrapper_output', output)
        print('SkipDropoutWrapper_new_state', new_state)

        if isinstance(self._cell, SkipLSTMCell):
            output, state_gate = output

        if _should_dropout(self._state_keep_prob):
            # Identify which subsets of the state to perform dropout on and
            # which ones to keep.
            shallow_filtered_substructure = nest.get_traverse_shallow_structure(
                self._dropout_state_filter, new_state)
            new_state = self._dropout(new_state, "state",
                                      self._recurrent_state_noise,
                                      self._state_keep_prob,
                                      shallow_filtered_substructure)
            print('SkipDropoutWrapper_new_state', new_state)
        if _should_dropout(self._output_keep_prob):
            output = self._dropout(output, "output",
                                   self._recurrent_output_noise,
                                   self._output_keep_prob)
            print('SkipDropoutWrapper_new_output', output)
        if isinstance(self._cell, MultiSkipLSTMCell):
            final_state = SkipLSTMStateTuple(new_state[-1].c, new_state[-1].h,
                                             up, cup)
            new_state[-1] = final_state
            output = SkipLSTMOutputTuple(output, state_gate)
        elif isinstance(self._cell, SkipLSTMCell):
            new_state = SkipLSTMStateTuple(new_state.c, new_state.h, up, cup)
            output = SkipLSTMOutputTuple(output, state_gate)
        print('SkipDropoutWrapper_new_output', output)
        print('SkipDropoutWrapper_new_state', new_state)
        return output, new_state