def testGetTraverseShallowStructure(self):
  """Checks results and error paths of `nest.get_traverse_shallow_structure`.

  Uses `assertRaisesRegex`: the `assertRaisesRegexp` alias was deprecated in
  Python 3.2 and removed from `unittest.TestCase` in Python 3.12.
  """
  # Scalar bools from traverse_fn: tuples are not traversed (False), every
  # other container is traversed.
  scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
  scalar_traverse_r = nest.get_traverse_shallow_structure(
      lambda s: not isinstance(s, tuple), scalar_traverse_input)
  self.assertEqual(scalar_traverse_r,
                   [True, True, False, [True, True], {"a": False}, []])
  nest.assert_shallow_structure(scalar_traverse_r, scalar_traverse_input)

  # A depth-1 structure of bools from traverse_fn selects, per child of a
  # tuple, whether that child is traversed.
  structure_traverse_input = [(1, [2]), ([1], 2)]
  structure_traverse_r = nest.get_traverse_shallow_structure(
      lambda s: (True, False) if isinstance(s, tuple) else True,
      structure_traverse_input)
  self.assertEqual(structure_traverse_r, [(True, False), ([True], False)])
  nest.assert_shallow_structure(structure_traverse_r, structure_traverse_input)

  # A structure returned for a non-structure input is rejected.
  with self.assertRaisesRegex(TypeError, "returned structure"):
    nest.get_traverse_shallow_structure(lambda _: [True], 0)

  # A scalar that is not a bool is rejected.
  with self.assertRaisesRegex(TypeError, "returned a non-bool scalar"):
    nest.get_traverse_shallow_structure(lambda _: 1, [1])

  # A depth-1 structure containing non-bools is rejected.
  with self.assertRaisesRegex(
      TypeError, "didn't return a depth=1 structure of bools"):
    nest.get_traverse_shallow_structure(lambda _: [1], [1])
def testGetTraverseShallowStructure(self):
  """Exercises `nest.get_traverse_shallow_structure` on valid and bad input.

  `assertRaisesRegex` replaces the `assertRaisesRegexp` alias, which was
  removed from `unittest.TestCase` in Python 3.12.
  """
  # traverse_fn returning scalar bools: stop at tuples, expand the rest.
  scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
  scalar_traverse_r = nest.get_traverse_shallow_structure(
      lambda s: not isinstance(s, tuple), scalar_traverse_input)
  self.assertEqual(scalar_traverse_r,
                   [True, True, False, [True, True], {"a": False}, []])
  nest.assert_shallow_structure(scalar_traverse_r, scalar_traverse_input)

  # traverse_fn returning a depth-1 bool structure for tuples: expand only the
  # first element of each tuple.
  structure_traverse_input = [(1, [2]), ([1], 2)]
  structure_traverse_r = nest.get_traverse_shallow_structure(
      lambda s: (True, False) if isinstance(s, tuple) else True,
      structure_traverse_input)
  self.assertEqual(structure_traverse_r, [(True, False), ([True], False)])
  nest.assert_shallow_structure(structure_traverse_r, structure_traverse_input)

  # Invalid traverse_fn results raise TypeError.
  with self.assertRaisesRegex(TypeError, "returned structure"):
    nest.get_traverse_shallow_structure(lambda _: [True], 0)
  with self.assertRaisesRegex(TypeError, "returned a non-bool scalar"):
    nest.get_traverse_shallow_structure(lambda _: 1, [1])
  with self.assertRaisesRegex(
      TypeError, "didn't return a depth=1 structure of bools"):
    nest.get_traverse_shallow_structure(lambda _: [1], [1])
def __call__(self, inputs, state, scope=None):
  """Run the cell with the declared dropouts.

  When `self.is_train` is true, inputs/state/output go through
  `self._dropout`; otherwise they are multiplied by their keep probability.
  For the state, the inference branch scales only the parts selected by
  `self._dropout_state_filter`. It maps over the state structure instead of
  multiplying the whole state, which would raise `TypeError` for tuple
  states such as `LSTMStateTuple`.

  Args:
    inputs: Input passed to the wrapped cell.
    state: State passed to the wrapped cell.
    scope: Optional scope, forwarded positionally to the wrapped cell.

  Returns:
    A pair `(output, new_state)` from the wrapped cell after dropout (or
    keep-probability scaling).
  """
  # NOTE(review): if `self._dropout` implements inverted dropout (as
  # `tf.nn.dropout` does, scaling kept units by 1/keep_prob during training),
  # the inference branches below should be the identity rather than
  # `* keep_prob`. Confirm against `_dropout` before relying on these values.

  def _should_dropout(p):
    # Only a Python float can be ruled out statically; any other keep
    # probability (e.g. a tensor) always takes the dropout path.
    return (not isinstance(p, float)) or p < 1

  if _should_dropout(self._input_keep_prob):
    inputs = tf.cond(
        self.is_train,
        lambda: self._dropout(inputs, "input", self._recurrent_input_noise,
                              self._input_keep_prob),
        lambda: inputs * self._input_keep_prob)

  output, new_state = self._cell(inputs, state, scope)

  if _should_dropout(self._state_keep_prob):
    # Identify which subsets of the state to perform dropout on and
    # which ones to keep.
    shallow_filtered_substructure = nest.get_traverse_shallow_structure(
        self._dropout_state_filter, new_state)
    # Bind the cell's state under its own name so the branch closures do not
    # depend on `new_state` being reassigned by this statement.
    cell_state = new_state

    def _scale_state(do_dropout, value):
      # `self._dropout` receives the same filter structure; scale only the
      # parts it marks for dropout and pass the rest through unchanged.
      if not isinstance(do_dropout, bool) or do_dropout:
        return value * self._state_keep_prob
      return value

    new_state = tf.cond(
        self.is_train,
        lambda: self._dropout(cell_state, "state", self._recurrent_state_noise,
                              self._state_keep_prob,
                              shallow_filtered_substructure),
        lambda: nest.map_structure_up_to(shallow_filtered_substructure,
                                         _scale_state,
                                         shallow_filtered_substructure,
                                         cell_state))

  if _should_dropout(self._output_keep_prob):
    output = tf.cond(
        self.is_train,
        lambda: self._dropout(output, "output", self._recurrent_output_noise,
                              self._output_keep_prob),
        lambda: output * self._output_keep_prob)

  return output, new_state
def _get_event_shape_shallow_structure(event_shape):
  """Builds a shallow structure for `event_shape`.

  A list or tuple whose elements are all ints is kept as a single atomic
  leaf (it is not traversed); every other list/tuple, and every non-list,
  non-tuple node, is traversed.
  """

  def _is_traversable(node):
    if not isinstance(node, (list, tuple)):
      return True
    # Traverse only if at least one element is not an int.
    return any(not isinstance(dim, int) for dim in node)

  return nest.get_traverse_shallow_structure(_is_traversable, event_shape)
def testNestGetTraverseShallowStructure(self):
  """With expand_composites=True, CT 'A' is expanded while CT 'B' is not."""

  def traverse_fn(node):
    # Stop only at CT instances whose metadata is 'B'.
    stop_here = isinstance(node, CT) and node.metadata == 'B'
    return not stop_here

  nested = [CT([1, 2], metadata='A'), CT([CT(3)], metadata='B')]
  shallow = nest.get_traverse_shallow_structure(
      traverse_fn, nested, expand_composites=True)
  self.assertEqual(shallow, [CT([True, True], metadata='A'), False])
def independent_joint_distribution_from_structure(structure_of_distributions,
                                                  validate_args=False):
  """Turns a (potentially nested) structure of dists into a single dist.

  Args:
    structure_of_distributions: instance of `tfd.Distribution`, or nested
      structure (tuple, list, dict, etc.) in which all leaves are
      `tfd.Distribution` instances.
    validate_args: Python `bool`. Whether the joint distribution should
      validate input with asserts. This imposes a runtime cost. If
      `validate_args` is `False`, and the inputs are invalid, correct behavior
      is not guaranteed. Applied to nested joint distributions as well.
      Default value: `False`.

  Returns:
    distribution: instance of `tfd.Distribution` such that
      `distribution.sample()` is equivalent to
      `tf.nest.map_structure(lambda d: d.sample(), structure_of_distributions)`.
      If `structure_of_distributions` was indeed a structure (as opposed to
      a single `Distribution` instance), this will be a `JointDistribution`
      with the corresponding structure.

  Raises:
    TypeError: if any leaves of the input structure are not `tfd.Distribution`
      instances.
  """
  # `collections.Mapping` was removed in Python 3.10; use `collections.abc`.
  from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

  # If input is already a Distribution, just return it.
  if dist_util.is_distribution_instance(structure_of_distributions):
    return structure_of_distributions

  # Depth of every leaf = length of its path from the root.
  element_depths = nest.map_structure_with_tuple_paths(
      lambda path, x: len(path), structure_of_distributions)
  # If this structure contains other structures (ie, has elements at depth > 1),
  # recursively turn them into JDs. `default=0` lets an empty structure fall
  # through instead of making `max` raise.
  if max(tf.nest.flatten(element_depths), default=0) > 1:
    # Always expand the root. Below it, depth-1 leaves map to True and every
    # child container (all of whose leaves have depth >= 2) maps to False, so
    # it is handed whole to the recursive call. Without the root check, a root
    # whose children are all containers would map to False and the recursive
    # call would receive the same structure again, never terminating.
    next_level_shallow_structure = nest.get_traverse_shallow_structure(
        traverse_fn=lambda x: (x is element_depths or
                               min(tf.nest.flatten(x)) <= 1),
        structure=element_depths)
    structure_of_distributions = nest.map_structure_up_to(
        next_level_shallow_structure,
        # Forward `validate_args` so nested JDs are built with the same setting.
        lambda substructure: independent_joint_distribution_from_structure(
            substructure, validate_args=validate_args),
        structure_of_distributions)

  # Build a JD from the (now depth-1) structure.
  if (hasattr(structure_of_distributions, '_asdict') or
      isinstance(structure_of_distributions, collections_abc.Mapping)):
    return joint_distribution_named.JointDistributionNamed(
        structure_of_distributions, validate_args=validate_args)
  return joint_distribution_sequential.JointDistributionSequential(
      structure_of_distributions, validate_args=validate_args)
def __call__(self, inputs, state, scope=None):
  """Run the cell with the declared dropouts.

  The last element of `inputs` along the split axis is a phase value. It is
  split off and stored on `self.phase`. Input dropout is applied to the
  remaining features only, and the phase is re-appended before the wrapped
  cell is called. The split axis is axis 1 for inputs of rank > 1, axis 0
  otherwise.

  Args:
    inputs: Tensor whose last element along the split axis is the phase.
    state: State passed to the wrapped cell.
    scope: Accepted for interface compatibility; not forwarded to the cell.

  Returns:
    A pair `(output, new_state)` from the wrapped cell after dropout.
  """
  # Same axis choice as before: axis 1 for rank > 1 inputs, else axis 0.
  phase_axis = 1 if len(inputs.shape) > 1 else 0
  num_features = inputs.shape[phase_axis].value
  # Store the phase on the wrapper and keep it out of input dropout.
  inputs, self.phase = array_ops.split(
      inputs, [num_features - 1, 1], axis=phase_axis)

  def _should_dropout(p):
    return (not isinstance(p, float)) or p < 1

  if _should_dropout(self._input_keep_prob):
    inputs = self._dropout(inputs, "input", self._recurrent_input_noise,
                           self._input_keep_prob)

  # Re-append phase so PFNN can use it.
  inputs = array_ops.concat([inputs, self.phase], phase_axis)

  # NOTE(review): `scope` is not forwarded to the wrapped cell (unchanged
  # behaviour); confirm this is intentional.
  output, new_state = self._cell(inputs, state)

  if _should_dropout(self._state_keep_prob):
    # Identify which subsets of the state to perform dropout on and
    # which ones to keep.
    shallow_filtered_substructure = nest.get_traverse_shallow_structure(
        self._dropout_state_filter, new_state)
    new_state = self._dropout(new_state, "state", self._recurrent_state_noise,
                              self._state_keep_prob,
                              shallow_filtered_substructure)

  if _should_dropout(self._output_keep_prob):
    output = self._dropout(output, "output", self._recurrent_output_noise,
                           self._output_keep_prob)

  return output, new_state
def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
  """Apply the configured dropouts around one step of the wrapped cell.

  Args:
    inputs: Input tensor for the wrapped cell.
    state: Tensor or tuple of tensors holding the wrapped cell's state.
    cell_call_fn: Step function of the wrapped cell (its `__call__` or
      `call` method).
    **kwargs: Extra keyword arguments forwarded to `cell_call_fn`.

  Returns:
    `(output, new_state)` as produced by `cell_call_fn`, with dropout applied
    to each piece whose keep probability requires it.
  """

  def _needs_dropout(keep_prob):
    # A Python float can be checked now; anything else (e.g. a tensor)
    # always goes through dropout.
    if isinstance(keep_prob, float):
      return keep_prob < 1
    return True

  if _needs_dropout(self._input_keep_prob):
    inputs = self._dropout(
        inputs, "input", self._recurrent_input_noise, self._input_keep_prob)

  cell_output, cell_state = cell_call_fn(inputs, state, **kwargs)

  if _needs_dropout(self._state_keep_prob):
    # Decide which parts of the state dropout touches and which pass through.
    state_filter_structure = nest.get_traverse_shallow_structure(
        self._dropout_state_filter, cell_state)
    cell_state = self._dropout(
        cell_state, "state", self._recurrent_state_noise,
        self._state_keep_prob, state_filter_structure)

  if _needs_dropout(self._output_keep_prob):
    cell_output = self._dropout(
        cell_output, "output", self._recurrent_output_noise,
        self._output_keep_prob)

  return cell_output, cell_state
def __call__(self, inputs, state, scope=None):
  """Step the wrapped cell, applying whichever dropouts are configured.

  Args:
    inputs: Input for the wrapped cell.
    state: State for the wrapped cell.
    scope: Optional scope, forwarded to the wrapped cell as `scope=`.

  Returns:
    `(output, new_state)` of the wrapped cell after dropout.
  """

  def _needs_dropout(keep_prob):
    # Non-float keep probabilities (e.g. tensors) always take the dropout path.
    if isinstance(keep_prob, float):
      return keep_prob < 1
    return True

  if _needs_dropout(self._input_keep_prob):
    inputs = self._dropout(
        inputs, "input", self._recurrent_input_noise, self._input_keep_prob)

  step_output, step_state = self._cell(inputs, state, scope=scope)

  if _needs_dropout(self._state_keep_prob):
    # Select the parts of the state that dropout applies to.
    state_filter_structure = nest.get_traverse_shallow_structure(
        self._dropout_state_filter, step_state)
    step_state = self._dropout(
        step_state, "state", self._recurrent_state_noise,
        self._state_keep_prob, state_filter_structure)

  if _needs_dropout(self._output_keep_prob):
    step_output = self._dropout(
        step_output, "output", self._recurrent_output_noise,
        self._output_keep_prob)

  return step_output, step_state
def independent_joint_distribution_from_structure(structure_of_distributions,
                                                  batch_ndims=None,
                                                  validate_args=False):
  """Turns a (potentially nested) structure of dists into a single dist.

  Args:
    structure_of_distributions: instance of `tfd.Distribution`, or nested
      structure (tuple, list, dict, etc.) in which all leaves are
      `tfd.Distribution` instances.
    batch_ndims: Optional integer `Tensor` number of leftmost batch dimensions
      shared across all members of the input structure. If this is specified,
      the returned joint distribution will be an autobatched distribution with
      the given batch rank, and all other dimensions absorbed into the event.
    validate_args: Python `bool`. Whether the joint distribution should
      validate input with asserts. This imposes a runtime cost. If
      `validate_args` is `False`, and the inputs are invalid, correct behavior
      is not guaranteed.
      Default value: `False`.

  Returns:
    distribution: instance of `tfd.Distribution` such that
      `distribution.sample()` is equivalent to
      `tf.nest.map_structure(lambda d: d.sample(), structure_of_distributions)`.
      If `structure_of_distributions` was indeed a structure (as opposed to
      a single `Distribution` instance), this will be a `JointDistribution`
      with the corresponding structure.

  Raises:
    TypeError: if any leaves of the input structure are not `tfd.Distribution`
      instances.
  """
  # If input is already a Distribution, return it, wrapping it in
  # `Independent` when its batch rank exceeds `batch_ndims`.
  if dist_util.is_distribution_instance(structure_of_distributions):
    dist = structure_of_distributions
    if batch_ndims is not None:
      excess_ndims = ps.rank_from_shape(
          dist.batch_shape_tensor()) - batch_ndims
      if tf.get_static_value(
          excess_ndims) != 0:  # Static value may be None.
        dist = independent.Independent(
            dist, reinterpreted_batch_ndims=excess_ndims)
    return dist

  # If this structure contains other structures (ie, has elements at depth > 1),
  # recursively turn them into JDs. `default=0` lets an empty structure fall
  # through instead of making `max` raise.
  element_depths = nest.map_structure_with_tuple_paths(
      lambda path, x: len(path), structure_of_distributions)
  if max(tf.nest.flatten(element_depths), default=0) > 1:
    # Always expand the root. Below it, depth-1 leaves map to True and each
    # child container (all of whose leaves have depth >= 2) maps to False, so
    # it is passed whole to the recursive call. Without the root check, a root
    # whose children are all containers would map to False and the recursive
    # call would receive the same structure again, never terminating.
    next_level_shallow_structure = nest.get_traverse_shallow_structure(
        traverse_fn=lambda x: (x is element_depths or
                               min(tf.nest.flatten(x)) <= 1),
        structure=element_depths)
    structure_of_distributions = nest.map_structure_up_to(
        next_level_shallow_structure,
        functools.partial(independent_joint_distribution_from_structure,
                          batch_ndims=batch_ndims,
                          validate_args=validate_args),
        structure_of_distributions)

  jdnamed = joint_distribution_named.JointDistributionNamed
  jdsequential = joint_distribution_sequential.JointDistributionSequential
  # Use an autobatched JD if a specific batch rank was requested.
  if batch_ndims is not None:
    jdnamed = functools.partial(
        joint_distribution_auto_batched.JointDistributionNamedAutoBatched,
        batch_ndims=batch_ndims, use_vectorized_map=False)
    jdsequential = functools.partial(
        joint_distribution_auto_batched.JointDistributionSequentialAutoBatched,
        batch_ndims=batch_ndims, use_vectorized_map=False)

  # Build a JD from the (now depth-1) structure: named for mappings and
  # namedtuple-like objects, sequential otherwise.
  if (hasattr(structure_of_distributions, '_asdict') or
      isinstance(structure_of_distributions, collections.abc.Mapping)):
    return jdnamed(structure_of_distributions, validate_args=validate_args)
  return jdsequential(structure_of_distributions, validate_args=validate_args)
def _get_shallow_structure(struct):
  """Returns a one-level shallow version of `struct`.

  Only the root object itself is traversed, so each immediate child of
  `struct` appears as `False` in the result.
  """

  def _traverse_root_only(node):
    return node is struct

  return nest.get_traverse_shallow_structure(_traverse_root_only, struct)
def get_shallow_tree(is_leaf, tree):
  """Builds a shallow version of `tree`.

  A subtree is traversed only when `is_leaf(subtree)` is falsy; subtrees for
  which `is_leaf` is truthy are not expanded.
  """

  def _should_expand(subtree):
    return not is_leaf(subtree)

  return nest.get_traverse_shallow_structure(_should_expand, tree)
def __call__(self, inputs, state, scope=None):
  """Run the cell with the declared dropouts.

  For `SkipLSTMCell` / `MultiSkipLSTMCell`, the extra state entries (the last
  two fields of the skip state tuple) and the state gate carried in the
  output are held out of dropout. They are reattached afterwards as
  `SkipLSTMStateTuple` / `SkipLSTMOutputTuple`.

  Args:
    inputs: Input passed to the wrapped cell.
    state: Wrapped cell's state; a list is converted to a tuple first.
    scope: Optional scope, forwarded to the wrapped cell as `scope=`.

  Returns:
    A pair `(output, new_state)` from the wrapped cell after dropout.
  """
  if isinstance(state, list):
    state = tuple(state)

  def _should_dropout(p):
    return (not isinstance(p, float)) or p < 1

  if _should_dropout(self._input_keep_prob):
    # TODO: unpack SkipLSTMOutputTuple inputs before dropout; only needed
    # when several skip cells are stacked.
    inputs = self._dropout(inputs, "input", self._recurrent_input_noise,
                           self._input_keep_prob)

  output, new_state = self._cell(inputs, state, scope=scope)

  # Separate the skip-specific entries so dropout only sees LSTMStateTuple
  # (c, h) values.
  if isinstance(self._cell, MultiSkipLSTMCell):
    _, _, up, cup = new_state[-1]
    # Build from the cell's *new* state; rebuilding from the incoming `state`
    # would discard this step's update.
    new_state = [LSTMStateTuple(s.c, s.h) for s in new_state]
  elif isinstance(self._cell, SkipLSTMCell):
    c, h, up, cup = new_state
    new_state = LSTMStateTuple(c, h)

  # Hold the state gate out of output dropout. It is needed below for both
  # cell types; previously it was only unpacked for SkipLSTMCell, leaving
  # `state_gate` unbound for a MultiSkipLSTMCell that is not a subclass.
  # NOTE(review): assumes MultiSkipLSTMCell emits an (output, gate) pair,
  # as the SkipLSTMOutputTuple rebuild below implies.
  if isinstance(self._cell, (SkipLSTMCell, MultiSkipLSTMCell)):
    output, state_gate = output

  if _should_dropout(self._state_keep_prob):
    # Identify which subsets of the state to perform dropout on and
    # which ones to keep.
    shallow_filtered_substructure = nest.get_traverse_shallow_structure(
        self._dropout_state_filter, new_state)
    new_state = self._dropout(new_state, "state", self._recurrent_state_noise,
                              self._state_keep_prob,
                              shallow_filtered_substructure)

  if _should_dropout(self._output_keep_prob):
    output = self._dropout(output, "output", self._recurrent_output_noise,
                           self._output_keep_prob)

  # Reattach the skip-specific state entries and the state gate.
  if isinstance(self._cell, MultiSkipLSTMCell):
    new_state[-1] = SkipLSTMStateTuple(new_state[-1].c, new_state[-1].h,
                                       up, cup)
    output = SkipLSTMOutputTuple(output, state_gate)
  elif isinstance(self._cell, SkipLSTMCell):
    new_state = SkipLSTMStateTuple(new_state.c, new_state.h, up, cup)
    output = SkipLSTMOutputTuple(output, state_gate)

  return output, new_state