def _check_same_outputs(true_graph, false_graph):
  """Raises an error if true_graph and false_graph have different outputs."""

  def error(error_detail):
    raise TypeError(
        "true_fn and false_fn arguments to tf.cond must have the same number, "
        "type, and overall structure of return values.\n"
        "\n"
        "true_fn output: %s\n"
        "false_fn output: %s\n"
        "\n"
        "Error details:\n"
        "%s" % (true_graph.structured_outputs, false_graph.structured_outputs,
                error_detail))

  try:
    nest.assert_same_structure(true_graph.structured_outputs,
                               false_graph.structured_outputs,
                               expand_composites=True)
  except (ValueError, TypeError) as e:
    error(str(e))

  assert len(true_graph.outputs) == len(false_graph.outputs)
  for true_out, false_out in zip(true_graph.outputs, false_graph.outputs):
    if true_out.dtype != false_out.dtype:
      error("%s and %s have different types" % (true_out, false_out))
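# A minimal sketch (not part of the original sources) of the mismatch the
# check above guards against: tf.nest.assert_same_structure raises when two
# branch outputs differ in nested structure. The dict contents are made up.
import tensorflow as tf

true_out = {"logits": tf.constant([1.0]), "state": (tf.constant(0),)}
false_out = {"logits": tf.constant([2.0])}  # "state" missing in false branch

try:
  tf.nest.assert_same_structure(true_out, false_out, expand_composites=True)
except (ValueError, TypeError) as e:
  print("structure mismatch:", e)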
def compute(i, a_flat, tas):
  """The loop body of scan.

  Args:
    i: the loop counter.
    a_flat: the accumulator value(s), flattened.
    tas: the output accumulator TensorArray(s), flattened.

  Returns:
    [i + 1, a_flat, tas]: the updated counter + new accumulator values +
      updated TensorArrays

  Raises:
    TypeError: if initializer and fn() output structure do not match
    ValueError: if initializer and fn() output lengths do not match
  """
  packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
  packed_a = output_pack(a_flat)
  a_out = fn(packed_a, packed_elems)
  nest.assert_same_structure(elems if initializer is None else initializer,
                             a_out)
  flat_a_out = output_flatten(a_out)
  tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
  if reverse:
    next_i = i - 1
  else:
    next_i = i + 1
  return (next_i, flat_a_out, tas)
def wrapped_body(loop_counter, *args):
  """Loop body augmented with counter update.

  Args:
    loop_counter: Loop counter which needs to be incremented in the body.
    *args: List of args
      args[:len_orig_loop_vars] - Args for the original loop body.
      args[len_orig_loop_vars:] - External captures of cond. These get
        passed through as is.

  Returns:
    A list of tensors the same length as args.
  """
  # Convert the flow variables in `args` to TensorArrays. `args` should
  # already have the same structure as `orig_loop_vars` but currently there
  # is no nest.zip so we call `_pack_sequence_as` which flattens both
  # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
  # and packs it into the structure of `orig_loop_vars`.
  outputs = body(
      *_pack_sequence_as(orig_loop_vars, args[:len_orig_loop_vars]))
  if not nest.is_sequence(outputs):
    outputs = [outputs]
  # Compare the structure of input and output of body converting the
  # top-level tuples to list to be compatible with legacy while_loop.
  nest.assert_same_structure(list(outputs), list(orig_loop_vars))

  outputs = _tensor_array_to_flow(outputs)

  # Return the external_captures of cond_graph as is, i.e., treat them as
  # loop invariants.
  # TODO(srbs): Update lowering code to create _Enter nodes with
  # is_constant=True for inputs that are directly passed to outputs.
  return [loop_counter + 1] + list(outputs) + list(
      args[len_orig_loop_vars:])
def wrapped_body(loop_counter, *args):
  """Loop body augmented with counter update.

  Args:
    loop_counter: Loop counter which needs to be incremented in the body.
    *args: List of args

  Returns:
    A list of tensors the same length as args.
  """
  # Capture the tensors already captured in cond_graph so that they appear
  # in the same order in body_graph.external_captures.
  for t in cond_graph.external_captures:
    ops.get_default_graph().capture(t)

  # Convert the flow variables in `args` to TensorArrays. `args` should
  # already have the same structure as `orig_loop_vars` but currently there
  # is no nest.zip so we call `_pack_sequence_as` which flattens both
  # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
  # and packs it into the structure of `orig_loop_vars`.
  outputs = body(*_pack_sequence_as(orig_loop_vars, args))
  if not nest.is_sequence(outputs):
    outputs = [outputs]
  # Compare the structure of input and output of body converting the
  # top-level tuples to list to be compatible with legacy while_loop.
  nest.assert_same_structure(list(outputs), list(orig_loop_vars))

  outputs = _tensor_array_to_flow(outputs)

  # TODO(srbs): Update lowering code to create _Enter nodes with
  # is_constant=True for inputs that are directly passed to outputs.
  return [loop_counter + 1] + list(outputs)
def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`."""
  expected_structure = function.graph.structured_input_signature
  try:
    flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
  except (TypeError, ValueError):
    return False
  try:
    # Verify that no input elements were dropped during flattening.
    repacked = nest.pack_sequence_as(expected_structure, flatten_inputs)
    # TODO(b/129422719): Namedtuple subclasses re-created through
    # saved_model.load don't compare equal in type to the original in
    # assert_same_structure. Fix that and we can take out check_types=False
    # here.
    nest.assert_same_structure(inputs, repacked, check_types=False)
  except (TypeError, ValueError):
    return False

  for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
    if isinstance(expected, tensor_spec.TensorSpec):
      if allow_conversion:
        arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
      if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec):
        return False
      if arg.dtype != expected.dtype:
        return False
      if not expected.shape.is_compatible_with(arg.shape):
        return False
    else:
      if arg != expected:
        return False
  return True
def __init__(self, initial_state, mask=None, name="trainable_initial_state"):
  """Constructs the Module that introduces a trainable state in the graph.

  It receives an initial state that will be used as the initial values for
  the trainable variables that the module contains, and optionally a mask
  that indicates the parts of the initial state that should be learnable.

  Args:
    initial_state: tensor or arbitrarily nested iterables of tensors.
    mask: optional boolean mask. It should have the same nested structure
      as the given initial_state.
    name: module name.

  Raises:
    TypeError: if mask is not a list of booleans or None.
  """
  super(TrainableInitialState, self).__init__(name=name)

  # Since python 2.7, DeprecationWarning is ignored by default.
  # Turn on the warning:
  warnings.simplefilter("always", DeprecationWarning)
  warnings.warn("Use the trainable flag in initial_state instead.",
                DeprecationWarning, stacklevel=2)

  if mask is not None:
    flat_mask = nest.flatten(mask)
    if not all([isinstance(m, bool) for m in flat_mask]):
      raise TypeError("Mask should be None or a list of boolean values.")
    nest.assert_same_structure(initial_state, mask)

  self._mask = mask
  self._initial_state = initial_state
def body(time, elements_finished, current_input, emit_ta, state, loop_state):
  """Internal while loop body for raw_rnn.

  Args:
    time: time scalar.
    elements_finished: batch-size vector.
    current_input: possibly nested tuple of input tensors.
    emit_ta: possibly nested tuple of output TensorArrays.
    state: possibly nested tuple of state tensors.
    loop_state: possibly nested tuple of loop state tensors.

  Returns:
    Tuple having the same size as Args but with updated values.
  """
  (next_output, cell_state) = cell(current_input, state)

  nest.assert_same_structure(state, cell_state)
  nest.assert_same_structure(cell.output_size, next_output)

  next_time = time + 1
  (next_finished, next_input, next_state, emit_output,
   next_loop_state) = loop_fn(next_time, next_output, cell_state, loop_state)

  nest.assert_same_structure(state, next_state)
  nest.assert_same_structure(current_input, next_input)
  nest.assert_same_structure(emit_ta, emit_output)

  # If loop_fn returns None for next_loop_state, just reuse the
  # previous one.
  loop_state = loop_state if next_loop_state is None else next_loop_state

  def _copy_some_through(current, candidate):
    """Copy some tensors through via array_ops.where."""
    current_flat = nest.flatten(current)
    candidate_flat = nest.flatten(candidate)
    # pylint: disable=g-long-lambda,cell-var-from-loop
    result_flat = [
        _on_device(
            lambda: array_ops.where(elements_finished, current_i, candidate_i),
            device=candidate_i.op.device)
        for (current_i, candidate_i) in zip(current_flat, candidate_flat)
    ]
    # pylint: enable=g-long-lambda,cell-var-from-loop
    return nest.pack_sequence_as(structure=current, flat_sequence=result_flat)

  emit_output = _copy_some_through(zero_emit, emit_output)
  next_state = _copy_some_through(state, next_state)

  emit_output_flat = nest.flatten(emit_output)
  emit_ta_flat = nest.flatten(emit_ta)

  elements_finished = math_ops.logical_or(elements_finished, next_finished)

  emit_ta_flat = [ta.write(time, emit)
                  for (ta, emit) in zip(emit_ta_flat, emit_output_flat)]

  emit_ta = nest.pack_sequence_as(structure=emit_structure,
                                  flat_sequence=emit_ta_flat)

  return (next_time, elements_finished, next_input,
          emit_ta, next_state, loop_state)
def _is_flat(sequence):
  sequence_flat = nest.flatten(sequence)
  try:
    nest.assert_same_structure(sequence_flat, sequence)
    return True
  except ValueError:
    return False
  except TypeError:
    return False
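# Hedged usage sketch for _is_flat above: a structure counts as flat when
# flattening it leaves the structure unchanged, so nesting of any depth
# makes the check fail.
assert _is_flat([1, 2, 3])        # already flat: flatten() is a no-op
assert not _is_flat([1, [2, 3]])  # nested: flatten() changes the structure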
def testNestAssertSameStructureCompositeMismatch(self, s1, s2,
                                                 error=ValueError):
  # s1 and s2 have the same structure if expand_composites=False; but
  # different structures if expand_composites=True.
  nest.assert_same_structure(s1, s2, expand_composites=False)
  nest.assert_shallow_structure(s1, s2, expand_composites=False)
  with self.assertRaises(error):  # pylint: disable=g-error-prone-assert-raises
    nest.assert_same_structure(s1, s2, expand_composites=True)
def _assert_correct_outputs(self, initial_state_):
  nest.assert_same_structure(initial_state_, self.decoder_cell.state_size)
  nest.assert_same_structure(initial_state_,
                             self.encoder_outputs.final_state)

  encoder_state_flat = nest.flatten(self.encoder_outputs.final_state)
  with self.test_session() as sess:
    encoder_state_flat_ = sess.run(encoder_state_flat)

  initial_state_flat_ = nest.flatten(initial_state_)
  for e_dec, e_enc in zip(initial_state_flat_, encoder_state_flat_):
    np.testing.assert_array_equal(e_dec, e_enc)
def insert(self, keys, values):
  nest.assert_same_structure(self._hash_tables, values)
  # Avoid race conditions by requiring that all inputs are computed before
  # any inserts happen (an issue if one key's update relies on another's
  # value).
  values_flat = [array_ops.identity(value) for value in nest.flatten(values)]
  with ops.control_dependencies(values_flat):
    insert_ops = [hash_table.insert(keys, value)
                  for hash_table, value
                  in zip(nest.flatten(self._hash_tables), values_flat)]
  return control_flow_ops.group(*insert_ops)
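# A minimal sketch (made-up variable, not from the original source) of the
# ordering pattern used above: snapshot every read with tf.identity, then
# gate the writes on those snapshots so all reads happen before any write.
import tensorflow as tf

v = tf.Variable([1.0, 2.0])
snapshot = tf.identity(v)              # read completes first
with tf.control_dependencies([snapshot]):
  update = v.assign_add([10.0, 10.0])  # write is ordered after the read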
def run_and_report(self, s1, s2, name):
  burn_iter, test_iter = 100, 30000

  for _ in xrange(burn_iter):
    nest.assert_same_structure(s1, s2)

  t0 = time.time()
  for _ in xrange(test_iter):
    nest.assert_same_structure(s1, s2)
  t1 = time.time()

  self.report_benchmark(
      iters=test_iter, wall_time=(t1 - t0) / test_iter, name=name)
def _maybe_copy_some_through():
  """Run RNN step. Pass through either no or some past state."""
  new_output, new_state = call_cell()

  nest.assert_same_structure(state, new_state)

  flat_new_state = nest.flatten(new_state)
  flat_new_output = nest.flatten(new_output)
  return control_flow_ops.cond(
      # if t < min_seq_len: calculate and return everything
      time < min_sequence_length, lambda: flat_new_output + flat_new_state,
      # else copy some of it through
      lambda: _copy_some_through(flat_new_output, flat_new_state))
def body(time, elements_finished, current_input, emit_ta, state, loop_state):
  """Internal while loop body for raw_rnn.

  Args:
    time: time scalar.
    elements_finished: batch-size vector.
    current_input: possibly nested tuple of input tensors.
    emit_ta: possibly nested tuple of output TensorArrays.
    state: possibly nested tuple of state tensors.
    loop_state: possibly nested tuple of loop state tensors.

  Returns:
    Tuple having the same size as Args but with updated values.
  """
  (next_output, cell_state) = cell(current_input, state)

  nest.assert_same_structure(state, cell_state)
  nest.assert_same_structure(cell.output_size, next_output)

  next_time = time + 1
  (next_finished, next_input, next_state, emit_output,
   next_loop_state) = loop_fn(next_time, next_output, cell_state, loop_state)

  nest.assert_same_structure(state, next_state)
  nest.assert_same_structure(current_input, next_input)
  nest.assert_same_structure(emit_ta, emit_output)

  # If loop_fn returns None for next_loop_state, just reuse the
  # previous one.
  loop_state = loop_state if next_loop_state is None else next_loop_state

  def _copy_some_through(current, candidate):
    """Copy some tensors through via array_ops.where."""

    def copy_fn(cur_i, cand_i):
      return _on_device(
          lambda: array_ops.where(elements_finished, cur_i, cand_i),
          device=cand_i.op.device)

    return nest.map_structure(copy_fn, current, candidate)

  emit_output = _copy_some_through(zero_emit, emit_output)
  next_state = _copy_some_through(state, next_state)

  emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit),
                               emit_ta, emit_output)

  elements_finished = math_ops.logical_or(elements_finished, next_finished)

  return (next_time, elements_finished, next_input,
          emit_ta, next_state, loop_state)
def check_mutation(n1, n2):
  """Check if two lists of arguments are exactly the same."""
  errmsg = ("Function to be traced should not modify structure of input "
            "arguments. Check if your function has list and dictionary "
            "operations that alter input arguments, "
            "such as `list.pop`, `list.append`")
  try:
    nest.assert_same_structure(n1, n2)
  except ValueError:
    raise ValueError(errmsg)

  for arg1, arg2 in zip(nest.flatten(n1), nest.flatten(n2)):
    if arg1 is not arg2:
      raise ValueError(errmsg)
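# Hedged sketch of what check_mutation above rejects: if the traced code
# appends to an input list, the argument structure differs between the
# "before" and "after" snapshots. Values are made up.
before = (1, [2, 3])
after = (1, [2, 3, 4])  # as if the traced function ran list.append
try:
  check_mutation(before, after)
except ValueError as e:
  print(e)  # "Function to be traced should not modify structure..."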
def testInitialStateComputation(self, tuple_state, mask):
  if tuple_state:
    initial_state = (tf.fill([BATCH_SIZE, 6], 2),
                     (tf.fill([BATCH_SIZE, 7], 3),
                      tf.fill([BATCH_SIZE, 8], 4)))
  else:
    initial_state = tf.fill([BATCH_SIZE, 9], 10)

  trainable_state_module = snt.TrainableInitialState(initial_state, mask=mask)
  trainable_state = trainable_state_module()
  flat_trainable_state = nest.flatten(trainable_state)
  nest.assert_same_structure(initial_state, trainable_state)
  flat_initial_state = nest.flatten(initial_state)
  if mask is not None:
    flat_mask = nest.flatten(mask)
  else:
    flat_mask = (True,) * len(flat_initial_state)

  self.evaluate(tf.global_variables_initializer())

  # Check that all variables are initialized correctly and that the module
  # returns a state with the same values as the one provided.
  for trainable_state, initial_state in zip(flat_trainable_state,
                                            flat_initial_state):
    self.assertAllEqual(
        self.evaluate(trainable_state), self.evaluate(initial_state))

  # Change the value of all the trainable variables to ones.
  for variable in tf.trainable_variables():
    self.evaluate(tf.assign(variable, tf.ones_like(variable)))

  # In eager mode, to re-evaluate the module we must re-connect it.
  trainable_state = trainable_state_module()
  flat_trainable_state = nest.flatten(trainable_state)

  # Check that the values of the initial_states have changed if and only if
  # they are trainable.
  for trainable_state, initial_state, mask in zip(flat_trainable_state,
                                                  flat_initial_state,
                                                  flat_mask):
    trainable_state_value = self.evaluate(trainable_state)
    initial_state_value = self.evaluate(initial_state)
    if mask:
      expected_value = np.ones_like(initial_state_value)
    else:
      expected_value = initial_state_value
    self.assertAllEqual(trainable_state_value, expected_value)
def test_convert_to_generator_like(self, input_fn, inputs):
  expected_batches = 5
  data = input_fn(self, inputs, expected_batches)

  # Dataset and Iterator not supported in Legacy Graph mode.
  if (not context.executing_eagerly() and
      isinstance(data, (dataset_ops.DatasetV2, iterator_ops.Iterator))):
    return

  generator, steps = training_generator.convert_to_generator_like(
      data, batch_size=2, steps_per_epoch=expected_batches)
  self.assertEqual(steps, expected_batches)
  for _ in range(expected_batches):
    outputs = next(generator)
  nest.assert_same_structure(outputs, inputs)
def testNestAssertSameStructure(self):
  st1 = sparse_tensor.SparseTensor([[0]], [0], [100])
  st2 = sparse_tensor.SparseTensor([[0, 3]], ['x'], [100, 100])
  test = TestCompositeTensor(st1.indices, st1.values, st1.dense_shape)

  nest.assert_same_structure(st1, st2, expand_composites=False)
  nest.assert_same_structure(st1, st2, expand_composites=True)
  nest.assert_same_structure(st1, test, expand_composites=False)
  with self.assertRaises(TypeError):
    nest.assert_same_structure(st1, test, expand_composites=True)
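# Hedged sketch of the expand_composites behavior exercised above: two
# SparseTensors match with expand_composites=True because both expand into
# the same (indices, values, dense_shape) components. Values are made up.
import tensorflow as tf

st1 = tf.sparse.SparseTensor([[0]], [1], [100])
st2 = tf.sparse.SparseTensor([[0], [3]], [4, 5], [100])
tf.nest.assert_same_structure(st1, st2, expand_composites=True)  # passes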
def testMapStructure(self):
  structure1 = (((1, 2), 3), 4, (5, 6))
  structure2 = (((7, 8), 9), 10, (11, 12))
  structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
  nest.assert_same_structure(structure1, structure1_plus1)
  self.assertAllEqual([2, 3, 4, 5, 6, 7], nest.flatten(structure1_plus1))
  structure1_plus_structure2 = nest.map_structure(
      lambda x, y: x + y, structure1, structure2)
  self.assertEqual(
      (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
      structure1_plus_structure2)

  self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))

  self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))

  with self.assertRaisesRegexp(TypeError, "callable"):
    nest.map_structure("bad", structure1_plus1)

  with self.assertRaisesRegexp(ValueError, "same nested structure"):
    nest.map_structure(lambda x, y: None, 3, (3,))

  with self.assertRaisesRegexp(TypeError, "same sequence type"):
    nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])

  with self.assertRaisesRegexp(ValueError, "same nested structure"):
    nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))

  structure1_list = [[[1, 2], 3], 4, [5, 6]]
  with self.assertRaisesRegexp(TypeError, "same sequence type"):
    nest.map_structure(lambda x, y: None, structure1, structure1_list)

  nest.map_structure(lambda x, y: None, structure1, structure1_list,
                     check_types=False)

  with self.assertRaisesRegexp(ValueError, "same nested structure"):
    nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
                       check_types=False)

  with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
    nest.map_structure(lambda x: None, structure1, foo="a")

  with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
    nest.map_structure(lambda x: None, structure1, check_types=False,
                       foo="a")
def __call__(self, *args):
  nest.assert_same_structure(self.shape_and_dtypes, args, check_types=False)
  if not all([
      shape.is_compatible_with(arg.shape)
      for shape, arg in zip(self.flattened_shapes, nest.flatten(args))
  ]):
    raise ValueError(
        "Declared shapes do not match argument shapes: Expected %s, found %s."
        % (self.flattened_shapes, [arg.shape for arg in nest.flatten(args)]))

  initialized = [resource_variable_ops.var_is_initialized_op(
      v.handle).numpy() for v in self._call_fn.variables]
  if all(x for x in initialized):
    return self._call_fn(*args)
  elif all(not x for x in initialized):
    return self._init_fn(*args)
  else:
    raise ValueError("Some, but not all, variables are initialized.")
def testMapStructureWithStrings(self):
  inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz"))
  inp_b = NestTest.ABTuple(a=2, b=(1, 3))
  out = nest.map_structure(lambda string, repeats: string * repeats,
                           inp_a, inp_b)
  self.assertEqual("foofoo", out.a)
  self.assertEqual("bar", out.b[0])
  self.assertEqual("bazbazbaz", out.b[1])

  nt = NestTest.ABTuple(a=("something", "something_else"),
                        b="yet another thing")
  rev_nt = nest.map_structure(lambda x: x[::-1], nt)
  # Check the output is the correct structure, and all strings are reversed.
  nest.assert_same_structure(nt, rev_nt)
  self.assertEqual(nt.a[0][::-1], rev_nt.a[0])
  self.assertEqual(nt.a[1][::-1], rev_nt.a[1])
  self.assertEqual(nt.b[::-1], rev_nt.b)
def _verify_structure_shapes_types(self, left, right):
  """Verify that the structure, shapes and types of left are same as right."""
  nest.assert_same_structure(left, right)
  flat_left = nest.flatten(left)
  flat_right = nest.flatten(right)
  assert len(flat_left) == len(flat_right), (
      "Length of left {} and right {} should be same.".
      format(len(flat_left), len(flat_right)))

  for o, i in zip(flat_left, flat_right):
    # TODO(priyag): Add checks for other types like IndexedSlices.
    if isinstance(o, ops.Tensor):
      assert isinstance(i, ops.Tensor)
      assert o.shape == i.shape, (
          "Shape {} of left {} doesn't match shape {} of right {}.".
          format(o.shape, o, i.shape, i))
      assert o.dtype == i.dtype, (
          "Dtype {} of left {} doesn't match dtype {} of right {}.".
          format(o.dtype, o, i.dtype, i))
def _check_same_outputs(true_graph, false_graph):
  """Raises an error if true_graph and false_graph have different outputs."""
  true_output_types = [t.dtype for t in true_graph.outputs]
  false_output_types = [t.dtype for t in false_graph.outputs]
  if (len(true_graph.outputs) != len(false_graph.outputs) or
      true_output_types != false_output_types):
    raise TypeError(
        "true_fn() and false_fn() must return the same number and type of "
        "arguments, got:\n"
        "  true_fn: %s\n"
        "  false_fn: %s" % (true_output_types, false_output_types))

  # Make sure `structured_outputs` for both graphs have the same structure.
  try:
    nest.assert_same_structure(true_graph.structured_outputs,
                               false_graph.structured_outputs)
  except (ValueError, TypeError) as e:
    raise ValueError("Outputs of true_fn and false_fn must have the same "
                     "structure: %s" % str(e))
def __call__(self, inputs, state, scope=None):
  """Run the cell and add its inputs to its outputs.

  Args:
    inputs: cell inputs.
    state: cell state.
    scope: optional cell scope.

  Returns:
    Tuple of cell outputs and new state.

  Raises:
    TypeError: If cell inputs and outputs have different structure (type).
    ValueError: If cell inputs and outputs have different structure (value).
  """
  output, new_state = self._cell(inputs, state, scope=scope)
  nest.assert_same_structure(inputs, output)
  res_output = nest.map_structure(lambda inp, out: inp + out, inputs, output)
  return (res_output, new_state)
def compute(i, tas):
  """The loop body of map_fn.

  Args:
    i: the loop counter
    tas: the flat TensorArray accumulator list

  Returns:
    (i + 1, tas): the updated counter + updated TensorArrays

  Raises:
    TypeError: if dtype and packed_fn_values structure do not match
    ValueError: if dtype and packed_fn_values lengths do not match
  """
  packed_values = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
  packed_fn_values = fn(packed_values)
  nest.assert_same_structure(dtype or elems, packed_fn_values)
  flat_fn_values = output_flatten(packed_fn_values)
  tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_fn_values)]
  return (i + 1, tas)
def assert_state_is_compatible(expected_state, state):
  """Asserts that states are compatible.

  Args:
    expected_state: The reference state.
    state: The state that must be compatible with :obj:`expected_state`.

  Raises:
    ValueError: if the states are incompatible.
  """
  # Check structure compatibility.
  nest.assert_same_structure(expected_state, state)

  # Check shape compatibility.
  expected_state_flat = nest.flatten(expected_state)
  state_flat = nest.flatten(state)
  for x, y in zip(expected_state_flat, state_flat):
    if tensor_util.is_tensor(x):
      with_same_shape(x, y)
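# Hedged usage sketch for assert_state_is_compatible above, assuming an
# LSTM-style (c, h) tuple state: a single flat tensor is rejected because
# the nested structures differ before shapes are even compared.
import tensorflow as tf

expected = (tf.zeros([4, 8]), tf.zeros([4, 8]))
try:
  assert_state_is_compatible(expected, tf.zeros([4, 16]))
except (ValueError, TypeError) as e:
  print("incompatible state:", e)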
def gnmt_residual_fn(inputs, outputs):
  """Residual function that handles different inputs and outputs inner dims.

  Args:
    inputs: cell inputs, this is actual inputs concatenated with the
      attention vector.
    outputs: cell outputs

  Returns:
    outputs + actual inputs
  """
  def split_input(inp, out):
    out_dim = out.get_shape().as_list()[-1]
    inp_dim = inp.get_shape().as_list()[-1]
    return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1)

  actual_inputs, _ = nest.map_structure(split_input, inputs, outputs)

  def assert_shape_match(inp, out):
    inp.get_shape().assert_is_compatible_with(out.get_shape())

  nest.assert_same_structure(actual_inputs, outputs)
  nest.map_structure(assert_shape_match, actual_inputs, outputs)
  return nest.map_structure(lambda inp, out: inp + out, actual_inputs,
                            outputs)
def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
  """Internal while_loop body.

  Args:
    time: scalar int32 tensor.
    outputs_ta: structure of TensorArray.
    state: (structure of) state tensors and TensorArrays.
    inputs: (structure of) input tensors.
    finished: bool tensor (keeping track of what's finished).
    sequence_lengths: int32 tensor (keeping track of time of finish).

  Returns:
    `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
      next_sequence_lengths)`.
  """
  (next_outputs, decoder_state, next_inputs,
   decoder_finished) = decoder.step(time, inputs, state)
  next_finished = math_ops.logical_or(decoder_finished, finished)
  if maximum_iterations is not None:
    next_finished = math_ops.logical_or(
        next_finished, time + 1 >= maximum_iterations)
  next_sequence_lengths = array_ops.where(
      math_ops.logical_and(math_ops.logical_not(finished), next_finished),
      array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
      sequence_lengths)

  nest.assert_same_structure(state, decoder_state)
  nest.assert_same_structure(outputs_ta, next_outputs)
  nest.assert_same_structure(inputs, next_inputs)

  # Zero out output values past finish
  if impute_finished:
    emit = nest.map_structure(
        lambda out, zero: array_ops.where(finished, zero, out),
        next_outputs, zero_outputs)
  else:
    emit = next_outputs

  # Copy through states past finish
  def _maybe_copy_state(new, cur):
    # TensorArrays and scalar states get passed through.
    if isinstance(cur, tensor_array_ops.TensorArray):
      pass_through = True
    else:
      new.set_shape(cur.shape)
      pass_through = (new.shape.ndims == 0)
    return new if pass_through else array_ops.where(finished, cur, new)

  if impute_finished:
    next_state = nest.map_structure(_maybe_copy_state, decoder_state, state)
  else:
    next_state = decoder_state

  outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                  outputs_ta, emit)
  return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
          next_sequence_lengths)
def body(time, outputs_ta, state, inputs, finished):
  """Internal while_loop body.

  Args:
    time: scalar int32 tensor.
    outputs_ta: structure of TensorArray.
    state: (structure of) state tensors and TensorArrays.
    inputs: (structure of) input tensors.
    finished: 1-D bool tensor.

  Returns:
    `(time + 1, outputs_ta, next_state, next_inputs, next_finished)`.
  """
  (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(
      time, inputs, state)
  next_finished = math_ops.logical_or(decoder_finished, finished)

  nest.assert_same_structure(state, decoder_state)
  nest.assert_same_structure(outputs_ta, next_outputs)
  nest.assert_same_structure(inputs, next_inputs)

  # Zero out output values past finish
  emit = nest.map_structure(
      lambda out, zero: array_ops.where(finished, zero, out),
      next_outputs, zero_outputs)

  # Copy through states past finish
  def _maybe_copy_state(new, cur):
    return (new if isinstance(cur, tensor_array_ops.TensorArray)
            else array_ops.where(finished, cur, new))

  next_state = nest.map_structure(_maybe_copy_state, decoder_state, state)
  outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                  outputs_ta, emit)
  return (time + 1, outputs_ta, next_state, next_inputs, next_finished)
def testInitialStateTuple(self, trainable, use_custom_initial_value,
                          state_size):
  batch_size = 6

  # Set the attribute on the class since we can't set properties of
  # abstract classes
  snt.RNNCore.state_size = state_size
  flat_state_size = nest.flatten(state_size)
  core = snt.RNNCore(name="dummy_core")
  if use_custom_initial_value:
    flat_initializer = [tf.constant_initializer(2)] * len(flat_state_size)
    trainable_initializers = nest.pack_sequence_as(
        structure=state_size, flat_sequence=flat_initializer)
  else:
    trainable_initializers = None
  initial_state = core.initial_state(
      batch_size, dtype=tf.float32,
      trainable=trainable,
      trainable_initializers=trainable_initializers)

  nest.assert_same_structure(initial_state, state_size)
  flat_initial_state = nest.flatten(initial_state)

  for state, size in zip(flat_initial_state, flat_state_size):
    self.assertEqual(state.get_shape(), [batch_size, size])

  with self.test_session() as sess:
    tf.global_variables_initializer().run()
    flat_initial_state_value = sess.run(flat_initial_state)
    for value, size in zip(flat_initial_state_value, flat_state_size):
      expected_initial_state = np.empty([batch_size, size])
      if not trainable:
        expected_initial_state.fill(0)
      elif use_custom_initial_value:
        expected_initial_state.fill(2)
      else:
        value_row = value[0]
        expected_initial_state = np.tile(value_row, (batch_size, 1))
      self.assertAllClose(value, expected_initial_state)
def testMapStructureOverPlaceholders(self):
  inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
           array_ops.placeholder(dtypes.float32, shape=[3, 7]))
  inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
           array_ops.placeholder(dtypes.float32, shape=[3, 7]))

  output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b)

  nest.assert_same_structure(output, inp_a)
  self.assertShapeEqual(np.zeros((3, 4)), output[0])
  self.assertShapeEqual(np.zeros((3, 7)), output[1])

  feed_dict = {
      inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)),
      inp_b: (np.random.randn(3, 4), np.random.randn(3, 7))
  }

  with self.cached_session() as sess:
    output_np = sess.run(output, feed_dict=feed_dict)
  self.assertAllClose(output_np[0],
                      feed_dict[inp_a][0] + feed_dict[inp_b][0])
  self.assertAllClose(output_np[1],
                      feed_dict[inp_a][1] + feed_dict[inp_b][1])
def _build(self, inputs):
  """Passes inputs to the initial states of decoder.

  :attr:`inputs` must either have the same structure as, or the same number
  of elements as, the decoder state.

  Args:
    inputs: The input (structure of) tensors to pass forward.

  Returns:
    The input (structure of) tensors that might be re-packed to have
    the same structure as the decoder state.
  """
  output = inputs
  try:
    nest.assert_same_structure(inputs, self._output_size)
  except (ValueError, TypeError):
    flat_input = nest.flatten(inputs)
    output = nest.pack_sequence_as(self._output_size, flat_input)
  self._built = True
  return output
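# Hedged sketch of the fallback path above: when the structures differ but
# the element counts match, pack_sequence_as re-packs the flat inputs into
# the target structure. The structures here are made up.
from tensorflow.python.util import nest

output_size = (None, (None, None))   # stand-in for a nested state size
inputs = [1, 2, 3]                   # flat, but with the right count
repacked = nest.pack_sequence_as(output_size, nest.flatten(inputs))
print(repacked)                      # (1, (2, 3))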
def gnmt_residual_fn(inputs, outputs):
  """Residual function that handles different inputs and outputs inner dims.

  Args:
    inputs: cell inputs, this is actual inputs concatenated with the
      attention vector. Refer to the GNMT RNN cell.
    outputs: cell outputs

  Returns:
    outputs + actual inputs
  """
  def split_input(inp, out):
    out_dim = out.get_shape().as_list()[-1]
    inp_dim = inp.get_shape().as_list()[-1]
    return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1)

  actual_inputs, _ = nest.map_structure(split_input, inputs, outputs)

  def assert_shape_match(inp, out):
    inp.get_shape().assert_is_compatible_with(out.get_shape())

  nest.assert_same_structure(actual_inputs, outputs)
  nest.map_structure(assert_shape_match, actual_inputs, outputs)
  return nest.map_structure(lambda inp, out: inp + out, actual_inputs,
                            outputs)
def __call__(self, inputs, state, scope=None):
  """Run the cell and add its inputs to its outputs.

  Args:
    inputs: cell inputs.
    state: cell state.
    scope: optional cell scope.

  Returns:
    Tuple of cell outputs and new state.

  Raises:
    TypeError: If cell inputs and outputs have different structure (type).
    ValueError: If cell inputs and outputs have different structure (value).
  """
  outputs, new_state = self._cell(inputs, state, scope=scope)
  nest.assert_same_structure(inputs, outputs)

  # Ensure shapes match
  def assert_shape_match(inp, out):
    inp.get_shape().assert_is_compatible_with(out.get_shape())

  nest.map_structure(assert_shape_match, inputs, outputs)
  res_outputs = nest.map_structure(lambda inp, out: inp + out,
                                   inputs, outputs)
  return (res_outputs, new_state)
def handle_partial_sample_weights(outputs, sample_weights, sample_weight_modes,
                                  check_all_flat=False):
  """Adds 1.0 as sample weights for the outputs for which there is no weight.

  Args:
    outputs: List of model outputs.
    sample_weights: List of sample weight inputs.
    sample_weight_modes: List of sample weight modes or None.
    check_all_flat: Ensure that inputs are not nested structures. This is not
      a free check, so we may not want to run it eagerly every iteration.

  Returns:
    Tuple of sample weights, one sample weight for every output, and booleans
    describing the raw sample weights.
  """
  any_sample_weight = sample_weights is not None and any(
      w is not None for w in sample_weights)
  partial_sample_weight = any_sample_weight and any(
      w is None for w in sample_weights)

  if not any_sample_weight:
    return None, any_sample_weight, partial_sample_weight

  if not partial_sample_weight:
    return sample_weights, any_sample_weight, partial_sample_weight

  if check_all_flat:
    nest.assert_same_structure(
        list_to_tuple(sample_weights),
        list_to_tuple(nest.flatten(sample_weights)))
    nest.assert_same_structure(
        list_to_tuple(outputs),
        list_to_tuple(nest.flatten(outputs)))
    if sample_weight_modes is not None:
      nest.assert_same_structure(
          sample_weight_modes, nest.flatten(sample_weight_modes))

  new_sample_weights = []
  for i, sw in enumerate(sample_weights):
    if sw is None:
      as_numpy = isinstance(outputs[i], np.ndarray)
      output = outputs[i]
      output_shape = output.shape if as_numpy else array_ops.shape(output)

      is_temporal = (sample_weight_modes is not None and
                     sample_weight_modes[i] == 'temporal')
      sw_shape = (output_shape[0],
                  output_shape[1]) if is_temporal else (output_shape[0],)

      new_sample_weights.append(
          np.ones(sw_shape) if as_numpy else array_ops.ones(sw_shape))
    else:
      new_sample_weights.append(sw)

  return (list_to_tuple(new_sample_weights),
          any_sample_weight, partial_sample_weight)
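# Hedged sketch of the check_all_flat assertion above: comparing a structure
# to its own flattened form only succeeds when the structure was already
# flat. (Plain tuple() stands in here for list_to_tuple, which recurses.)
from tensorflow.python.util import nest

flat = [0.5, 0.5]
nest.assert_same_structure(tuple(flat), tuple(nest.flatten(flat)))  # passes
nested = [0.5, [0.5]]
try:
  nest.assert_same_structure(tuple(nested), tuple(nest.flatten(nested)))
except ValueError as e:
  print("not flat:", e)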
def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`."""
  expected_structure = function.graph.structured_input_signature
  try:
    flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
  except (TypeError, ValueError):
    return False
  try:
    # Verify that no input elements were dropped during flattening.
    repacked = nest.pack_sequence_as(expected_structure, flatten_inputs)
    # TODO(b/129422719): Namedtuple subclasses re-created through
    # saved_model.load don't compare equal in type to the original in
    # assert_same_structure. Fix that and we can take out check_types=False
    # here.
    nest.assert_same_structure(inputs, repacked, check_types=False)
  except (TypeError, ValueError):
    return False

  for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
    if isinstance(expected, tensor_spec.TensorSpec):
      if allow_conversion:
        arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
      if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec):
        return False
      if arg.dtype != expected.dtype:
        return False
      if not expected.shape.is_compatible_with(arg.shape):
        return False
    elif isinstance(expected, type_spec.TypeSpec):
      return expected.is_compatible_with(arg)
    elif (_is_tensor(arg) and id(arg) != id(expected)) or (
        not _is_tensor(arg) and arg != expected):
      return False
  return True
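# Hedged sketch of the round-trip check above: flatten_up_to followed by
# pack_sequence_as reproduces the inputs exactly when nothing was dropped
# during flattening. The signature structure here is made up.
from tensorflow.python.util import nest

expected_structure = ((None, None), {"a": None})
inputs = ((1, 2), {"a": 3})
flat = nest.flatten_up_to(expected_structure, inputs)
repacked = nest.pack_sequence_as(expected_structure, flat)
nest.assert_same_structure(inputs, repacked, check_types=False)  # passes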
def test_state_override(self):
  test_start_state = (numpy.array([[2, 3, 4]]),
                      (numpy.array([2]), numpy.array([[3., 5.]])))
  data = {
      feature_keys.FilteringFeatures.TIMES: numpy.arange(5),
      feature_keys.FilteringFeatures.VALUES: numpy.zeros(shape=[5, 3])
  }
  features, _ = input_pipeline.WholeDatasetInputFn(
      input_pipeline.NumpyReader(data))()
  features[feature_keys.FilteringFeatures.STATE_TUPLE] = test_start_state
  stub_model = _StateOverrideModel()
  chainer = state_management.ChainingStateManager()
  stub_model.initialize_graph()
  chainer.initialize_graph(model=stub_model)
  model_outputs = chainer.define_loss(
      model=stub_model, features=features, mode=estimator_lib.ModeKeys.EVAL)
  with train.MonitoredSession() as session:
    end_state = session.run(model_outputs.end_state)
  nest.assert_same_structure(test_start_state, end_state)
  for expected, received in zip(
      nest.flatten(test_start_state), nest.flatten(end_state)):
    self.assertAllEqual(expected, received)
def compute(i, a_flat, tas):
  """The loop body of scan.

  Args:
    i: the loop counter.
    a_flat: the accumulator value(s), flattened.
    tas: the output accumulator TensorArray(s), flattened.

  Returns:
    [i + 1, a_flat, tas]: the updated counter + new accumulator values +
      updated TensorArrays

  Raises:
    TypeError: if initializer and fn() output structure do not match
    ValueError: if initializer and fn() output lengths do not match
  """
  packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
  packed_a = output_pack(a_flat)
  a_out = fn(packed_a, packed_elems)
  nest.assert_same_structure(elems if initializer is None else initializer,
                             a_out)
  flat_a_out = output_flatten(a_out)
  tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
  return (i + 1, flat_a_out, tas)
def testFromFunctionInputSignatureForPerReplicaValues(
    self, distribution, enable_get_next_as_optional, drop_remainder):
  # Create files that produce partial/empty batches at different batch
  # indices. Note that some workers will get empty batches even when
  # drop_remainder=True.
  fname1 = os.path.join(self.get_temp_dir(), "1.txt")
  _create_text_file(fname1, 5)
  fname2 = os.path.join(self.get_temp_dir(), "2.txt")
  _create_text_file(fname2, 9)

  def dataset_fn(input_context):
    dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2])
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    return readers.TextLineDatasetV2(dataset).map(
        string_ops.string_to_number).batch(
            input_context.get_per_replica_batch_size(4),
            drop_remainder=drop_remainder)

  distribution.extended.experimental_enable_get_next_as_optional = (
      enable_get_next_as_optional)
  ds = distribution.experimental_distribute_datasets_from_function(dataset_fn)
  _check_type_spec_structure(iter(ds))
  element_spec = ds.element_spec
  iter_element_spec = iter(ds).element_spec
  nest.assert_same_structure(element_spec, iter_element_spec)
  self.assertAllEqual(nest.flatten(element_spec),
                      nest.flatten(iter_element_spec))

  @def_function.function(input_signature=[element_spec])
  def process_inputs(inputs):
    distribution.run(lambda inputs: inputs, args=(inputs,))

  for x in ds:
    process_inputs(x)
def _make_bridging_callable(generator_fn, wrap_in_tuple, peek,
                            elements_to_keep, partial_sample_weight,
                            sample_weight_modes):
  """Optional compatibility layer between user's data and Dataset."""
  must_prune_nones = (elements_to_keep != len(peek))
  try:
    nest.assert_same_structure(peek, nest._list_to_tuple(peek))  # pylint: disable=protected-access
    must_extract_lists = False
  except TypeError:
    must_extract_lists = True

  # No additional transformations are needed.
  if not (wrap_in_tuple or must_extract_lists or must_prune_nones or
          partial_sample_weight):
    return generator_fn

  def wrapped_generator():
    """Remove Nones and lists before invoking Dataset.from_generator."""
    for batch in generator_fn():
      if wrap_in_tuple:
        batch = (batch,)

      if must_extract_lists:
        batch = nest._list_to_tuple(batch)  # pylint: disable=protected-access

      if must_prune_nones:
        batch = batch[:elements_to_keep]

      if partial_sample_weight:
        sample_weights, _, _ = training_utils.handle_partial_sample_weights(
            batch[1], batch[2], sample_weight_modes, check_all_flat=False)
        batch = batch[:2] + (sample_weights,)

      yield batch

  return wrapped_generator
def _build(self, inputs):
  """Transforms inputs to have the same structure as :attr:`output_size`.

  Values of the inputs are not changed.

  :attr:`inputs` must either have the same structure as, or the same number
  of elements as, :attr:`output_size`.

  Args:
    inputs: The input (structure of) tensor to pass forward.

  Returns:
    A (structure of) tensors that re-packs `inputs` to have the specified
    structure of `output_size`.
  """
  output = inputs
  try:
    nest.assert_same_structure(inputs, self._output_size)
  except (ValueError, TypeError):
    flat_input = nest.flatten(inputs)
    output = nest.pack_sequence_as(self._output_size, flat_input)
  self._built = True
  return output
def body(time, outputs_ta, state, inputs, finished):
  """Internal while_loop body.

  Args:
    time: scalar int32 tensor.
    outputs_ta: structure of TensorArray.
    state: (structure of) state tensors and TensorArrays.
    inputs: (structure of) input tensors.
    finished: 1-D bool tensor.

  Returns:
    `(time + 1, outputs_ta, next_state, next_inputs, next_finished)`.
  """
  (next_outputs, decoder_state, next_inputs,
   decoder_finished) = decoder.step(time, inputs, state)
  next_finished = math_ops.logical_or(decoder_finished, finished)
  if maximum_iterations is not None:
    next_finished = math_ops.logical_or(
        next_finished, time + 1 >= maximum_iterations)

  # Check that the structures match.
  nest.assert_same_structure(state, decoder_state)
  nest.assert_same_structure(outputs_ta, next_outputs)
  nest.assert_same_structure(inputs, next_inputs)

  # Zero out output values past finish
  if impute_finished:
    emit = nest.map_structure(
        lambda out, zero: array_ops.where(finished, zero, out),
        next_outputs, zero_outputs)
  else:
    emit = next_outputs

  # Copy through states past finish
  def _maybe_copy_state(new, cur):
    # TensorArrays and scalar states get passed through.
    if isinstance(cur, tensor_array_ops.TensorArray):
      pass_through = True
    else:
      new.set_shape(cur.shape)
      pass_through = (new.shape.ndims == 0)
    return new if pass_through else array_ops.where(finished, cur, new)

  if impute_finished:
    next_state = nest.map_structure(_maybe_copy_state, decoder_state, state)
  else:
    next_state = decoder_state

  outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                  outputs_ta, emit)
  return (time + 1, outputs_ta, next_state, next_inputs, next_finished)
def body(time, elements_finished, current_input, state, beam_seq, beam_prob,
         emit_ta, loop_state):
  """Internal while loop body for raw_rnn.

  Args:
    time: time scalar.
    elements_finished: batch-size vector.
    current_input: possibly nested tuple of input tensors.
    state: possibly nested tuple of state tensors.
    beam_seq: beam search sequence tensor.
    beam_prob: beam search probability tensor.
    emit_ta: possibly nested tuple of output TensorArrays.
    loop_state: possibly nested tuple of loop state tensors.

  Returns:
    Tuple having the same size as Args but with updated values.
  """
  dummy = array_ops.zeros(
      shape=[tf.shape(beam_seq)[0], tf.shape(beam_seq)[1], 20],
      dtype=tf.int32)
  (next_output, cell_state) = cell(current_input, state)

  nest.assert_same_structure(state, cell_state)
  nest.assert_same_structure(cell.output_size, next_output)

  # cell_output, cell_state, beam_seq, beam_prob, finished, emit_ta,
  # loop_state
  (next_input, elements_finished, next_state, beam_seq, beam_prob, emit_ta,
   next_loop_state) = loop_fn(next_output, cell_state, beam_seq, beam_prob,
                              elements_finished, emit_ta, loop_state)

  nest.assert_same_structure(state, next_state)
  nest.assert_same_structure(current_input, next_input)

  # If loop_fn returns None for next_loop_state, just reuse the
  # previous one.
  loop_state = loop_state if next_loop_state is None else next_loop_state

  next_time = time + 1
  return (next_time, elements_finished, next_input, next_state,
          beam_seq, beam_prob, emit_ta, loop_state)
def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
  next_outputs, next_state, next_inputs, decoder_finished = decoder.step(
      time, inputs, state)

  if decoder.tracks_own_finished:
    next_finished = decoder_finished
  else:
    next_finished = tf.logical_or(decoder_finished, finished)
  # Reasons for the reshape: 1) inside the helper, the tensor gets merged
  # when it goes through cond; 2) at inference time the value comes out 2-D.
  next_finished = tf.reshape(next_finished, [-1])
  next_sequence_lengths = tf.where(
      tf.logical_not(finished),
      x=tf.fill(tf.shape(sequence_lengths), time + 1),
      y=sequence_lengths)

  nest.assert_same_structure(state, next_state)
  nest.assert_same_structure(outputs_ta, next_outputs)
  nest.assert_same_structure(inputs, next_inputs)

  if impute_finished:
    new_linear = nest.map_structure(
        lambda out, zero: tf.where(finished, zero, out),
        next_outputs.linear,
        tf.zeros_like(next_outputs.linear))
    # _replace returns a new namedtuple, so the result must be reassigned.
    next_outputs = next_outputs._replace(linear=new_linear)

  def _maybe_copy_state(new, cur):
    if isinstance(cur, tf.TensorArray):
      pass_through = True
    else:
      new.set_shape(cur.shape)
      pass_through = (new.shape.ndims == 0)
    return new if pass_through else tf.where(finished, cur, new)

  next_state = nest.map_structure(_maybe_copy_state, next_state, state)
  outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                  outputs_ta, next_outputs)
  return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
          next_sequence_lengths)
def _verify_tf_loop_vars(init_vars,
                         iter_entry_vars,
                         iter_exit_vars,
                         symbol_names,
                         opts,
                         check_shapes=True):
  """Verifies loop variables for consistency."""
  if check_shapes and 'shape_invariants' in opts:
    shape_invariants = opts['shape_invariants']
  else:
    shape_invariants = nest.map_structure(lambda _: None, iter_entry_vars)

  assert len(symbol_names) == len(shape_invariants)
  assert len(symbol_names) == len(init_vars)
  assert len(symbol_names) == len(iter_entry_vars)
  assert len(symbol_names) == len(iter_exit_vars)

  for i in range(len(symbol_names)):
    name = symbol_names[i]
    init = init_vars[i]
    entry = iter_entry_vars[i]
    exit_ = iter_exit_vars[i]
    invariant = shape_invariants[i]

    try:
      nest.assert_same_structure(init, entry, expand_composites=True)
      nest.assert_same_structure(entry, exit_, expand_composites=True)
    except (ValueError, TypeError) as e:
      raise TypeError("'{}' does not have the same nested structure after one"
                      ' iteration.\n\n{}'.format(name, e))
    if invariant is not None:
      try:
        nest.assert_same_structure(init, invariant, expand_composites=False)
      except (ValueError, TypeError) as e:
        raise TypeError("'{}' does not have the same nested structure as its"
                        ' corresponding shape invariant.\n\n{}'.format(
                            name, e))

    nest.map_structure(
        functools.partial(_verify_single_loop_var, name, check_shapes), init,
        entry, exit_, invariant)
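# Hedged sketch of the failure mode _verify_tf_loop_vars reports: a loop
# variable whose nested structure changes across one iteration. Values are
# made up.
from tensorflow.python.util import nest

init = {"acc": 0}
after_one_iteration = {"acc": (0, 1)}  # scalar became a tuple
try:
  nest.assert_same_structure(init, after_one_iteration,
                             expand_composites=True)
except (ValueError, TypeError) as e:
  print("'acc' does not keep the same nested structure:", e)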
def testSameStructure(self):
  d = {1: "a"}
  nest.assert_same_structure(d, data_structures._DictWrapper(d.copy()))
def testSameStructure(self):
  l = [1]
  nest.assert_same_structure(l, data_structures._ListWrapper(copy.copy(l)))
def testNestAssertSameStructure(self, s1, s2, expand_composites=True):
  nest.assert_same_structure(s1, s2, expand_composites=expand_composites)
  nest.assert_shallow_structure(s1, s2, expand_composites=expand_composites)
def decoder(self, encoder_state, attn_features, prediction_inputs,
            previous_y):
  """
  :param encoder_state: shape [batch_size, encoder_rnn_depth]
  :param prediction_inputs: features for prediction days,
      tensor[batch_size, time, input_depth]
  :param previous_y: Last day pageviews, shape [batch_size]
  :param attn_features: Additional features from attention layer,
      shape [batch, predict_window, readout_depth*n_heads]
  :return: decoder rnn output
  """
  hparams = self.hparams

  def build_cell(idx):
    with tf.variable_scope('decoder_cell',
                           initializer=self.default_init(idx)):
      cell = rnn.GRUBlockCell(self.hparams.rnn_depth)
      has_dropout = (hparams.decoder_input_dropout[idx] < 1
                     or hparams.decoder_state_dropout[idx] < 1
                     or hparams.decoder_output_dropout[idx] < 1)

      if self.is_train and has_dropout:
        attn_depth = (attn_features.shape[-1].value
                      if attn_features is not None else 0)
        input_size = (attn_depth + prediction_inputs.shape[-1].value + 1
                      if idx == 0 else self.hparams.rnn_depth)
        cell = rnn.DropoutWrapper(
            cell, dtype=tf.float32, input_size=input_size,
            variational_recurrent=hparams.decoder_variational_dropout[idx],
            input_keep_prob=hparams.decoder_input_dropout[idx],
            output_keep_prob=hparams.decoder_output_dropout[idx],
            state_keep_prob=hparams.decoder_state_dropout[idx],
            seed=self.seed + idx)
      return cell

  if hparams.decoder_rnn_layers > 1:
    cells = [build_cell(idx) for idx in range(hparams.decoder_rnn_layers)]
    cell = rnn.MultiRNNCell(cells)
  else:
    cell = build_cell(0)

  nest.assert_same_structure(encoder_state, cell.state_size)
  predict_days = self.inp.predict_window
  assert prediction_inputs.shape[1] == predict_days

  # [batch_size, time, input_depth] -> [time, batch_size, input_depth]
  inputs_by_time = tf.transpose(prediction_inputs, [1, 0, 2])

  # Return raw outputs for RNN losses calculation
  return_raw_outputs = (self.hparams.decoder_stability_loss > 0.0
                        or self.hparams.decoder_activation_loss > 0.0)

  # Stop condition for decoding loop
  def cond_fn(time, prev_output, prev_state, array_targets: tf.TensorArray,
              array_outputs: tf.TensorArray):
    return time < predict_days

  # FC projecting layer to get single predicted value from RNN output
  def project_output(tensor):
    return tf.layers.dense(tensor, 1, name='decoder_output_proj',
                           kernel_initializer=self.default_init())

  def loop_fn(time, prev_output, prev_state, array_targets: tf.TensorArray,
              array_outputs: tf.TensorArray):
    """
    Main decoder loop
    :param time: Day number
    :param prev_output: Output(prediction) from previous step
    :param prev_state: RNN state tensor from previous step
    :param array_targets: Predictions, each step will append new value to
        this array
    :param array_outputs: Raw RNN outputs (for regularization losses)
    :return:
    """
    # RNN inputs for current step
    features = inputs_by_time[time]

    # [batch, predict_window, readout_depth * n_heads] ->
    # [batch, readout_depth * n_heads]
    if attn_features is not None:
      # [batch_size, 1] + [batch_size, input_depth]
      attn = attn_features[:, time, :]
      # Append previous predicted value + attention vector to input features
      next_input = tf.concat([prev_output, features, attn], axis=1)
    else:
      # Append previous predicted value to input features
      next_input = tf.concat([prev_output, features], axis=1)

    # Run RNN cell
    output, state = cell(next_input, prev_state)
    # Make prediction from RNN outputs
    projected_output = project_output(output)
    # Append step results to the buffer arrays
    if return_raw_outputs:
      array_outputs = array_outputs.write(time, output)
    array_targets = array_targets.write(time, projected_output)
    # Increment time and return
    return time + 1, projected_output, state, array_targets, array_outputs

  # Initial values for loop
  loop_init = [
      tf.constant(0, dtype=tf.int32),
      tf.expand_dims(previous_y, -1),
      encoder_state,
      tf.TensorArray(dtype=tf.float32, size=predict_days),
      tf.TensorArray(dtype=tf.float32, size=predict_days)
      if return_raw_outputs else tf.constant(0)
  ]
  # Run the loop
  _, _, _, targets_ta, outputs_ta = tf.while_loop(cond_fn, loop_fn,
                                                  loop_init)

  # Get final tensors from buffer arrays
  targets = targets_ta.stack()
  # [time, batch_size, 1] -> [time, batch_size]
  targets = tf.squeeze(targets, axis=-1)
  raw_outputs = outputs_ta.stack() if return_raw_outputs else None
  return targets, raw_outputs
def raw_rnn(cell, loop_fn, parallel_iterations=None, swap_memory=False,
            scope=None):
  """
  raw_rnn adapted from the original tensorflow implementation
  (https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn.py)
  to emit arbitrarily nested states for each time step (concatenated along
  the time axis) in addition to the outputs at each timestep and the final
  state

  returns (
      states for all timesteps,
      outputs for all timesteps,
      final cell state,
  )
  """
  if not _like_rnncell(cell):
    raise TypeError("cell must be an instance of RNNCell")
  if not callable(loop_fn):
    raise TypeError("loop_fn must be a callable")

  parallel_iterations = parallel_iterations or 32

  # Create a new scope in which the caching device is either
  # determined by the parent scope, or is set to place the cached
  # Variable using the same placement as for the rest of the RNN.
  with vs.variable_scope(scope or "rnn") as varscope:
    if not context.executing_eagerly():
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    time = constant_op.constant(0, dtype=dtypes.int32)
    (elements_finished, next_input,
     initial_state, emit_structure, init_loop_state) = loop_fn(
         time, None, None, None)
    flat_input = nest.flatten(next_input)

    # Need a surrogate loop state for the while_loop if none is available.
    loop_state = (init_loop_state if init_loop_state is not None
                  else constant_op.constant(0, dtype=dtypes.int32))

    input_shape = [input_.get_shape() for input_ in flat_input]
    static_batch_size = input_shape[0][0]

    for input_shape_i in input_shape:
      # Static verification that batch sizes all match
      static_batch_size.merge_with(input_shape_i[0])

    batch_size = static_batch_size.value
    const_batch_size = batch_size
    if batch_size is None:
      batch_size = array_ops.shape(flat_input[0])[0]

    nest.assert_same_structure(initial_state, cell.state_size)
    state = initial_state
    flat_state = nest.flatten(state)
    flat_state = [ops.convert_to_tensor(s) for s in flat_state]
    state = nest.pack_sequence_as(structure=state, flat_sequence=flat_state)

    if emit_structure is not None:
      flat_emit_structure = nest.flatten(emit_structure)
      flat_emit_size = [
          emit.shape if emit.shape.is_fully_defined()
          else array_ops.shape(emit) for emit in flat_emit_structure
      ]
      flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
    else:
      emit_structure = cell.output_size
      flat_emit_size = nest.flatten(emit_structure)
      flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

    flat_state_size = [
        s.shape if s.shape.is_fully_defined() else array_ops.shape(s)
        for s in flat_state
    ]
    flat_state_dtypes = [s.dtype for s in flat_state]

    flat_emit_ta = [
        tensor_array_ops.TensorArray(
            dtype=dtype_i,
            dynamic_size=True,
            element_shape=(tensor_shape.TensorShape([const_batch_size])
                           .concatenate(
                               _maybe_tensor_shape_from_tensor(size_i))),
            size=0,
            name="rnn_output_%d" % i)
        for i, (dtype_i, size_i)
        in enumerate(zip(flat_emit_dtypes, flat_emit_size))
    ]
    emit_ta = nest.pack_sequence_as(structure=emit_structure,
                                    flat_sequence=flat_emit_ta)
    flat_zero_emit = [
        array_ops.zeros(_concat(batch_size, size_i), dtype_i)
        for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)
    ]
    zero_emit = nest.pack_sequence_as(structure=emit_structure,
                                      flat_sequence=flat_zero_emit)

    flat_state_ta = [
        tensor_array_ops.TensorArray(
            dtype=dtype_i,
            dynamic_size=True,
            element_shape=(tensor_shape.TensorShape([const_batch_size])
                           .concatenate(
                               _maybe_tensor_shape_from_tensor(size_i))),
            size=0,
            name="rnn_state_%d" % i)
        for i, (dtype_i, size_i)
        in enumerate(zip(flat_state_dtypes, flat_state_size))
    ]
    state_ta = nest.pack_sequence_as(structure=state,
                                     flat_sequence=flat_state_ta)

    def condition(unused_time, elements_finished, *_):
      return math_ops.logical_not(math_ops.reduce_all(elements_finished))

    def body(time, elements_finished, current_input, state_ta, emit_ta,
             state, loop_state):
      (next_output, cell_state) = cell(current_input, state)

      nest.assert_same_structure(state, cell_state)
      nest.assert_same_structure(cell.output_size, next_output)

      next_time = time + 1
      (next_finished, next_input, next_state, emit_output,
       next_loop_state) = loop_fn(next_time, next_output, cell_state,
                                  loop_state)

      nest.assert_same_structure(state, next_state)
      nest.assert_same_structure(current_input, next_input)
      nest.assert_same_structure(emit_ta, emit_output)

      # If loop_fn returns None for next_loop_state, just reuse the
      # previous one.
      loop_state = (loop_state if next_loop_state is None
                    else next_loop_state)

      def _copy_some_through(current, candidate):
        """Copy some tensors through via array_ops.where."""

        def copy_fn(cur_i, cand_i):
          # TensorArray and scalar get passed through.
          if isinstance(cur_i, tensor_array_ops.TensorArray):
            return cand_i
          if cur_i.shape.ndims == 0:
            return cand_i
          # Otherwise propagate the old or the new value.
          with ops.colocate_with(cand_i):
            return array_ops.where(elements_finished, cur_i, cand_i)

        return nest.map_structure(copy_fn, current, candidate)

      emit_output = _copy_some_through(zero_emit, emit_output)
      next_state = _copy_some_through(state, next_state)

      emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit),
                                   emit_ta, emit_output)
      state_ta = nest.map_structure(
          lambda ta, state: ta.write(time, state), state_ta, next_state)

      elements_finished = math_ops.logical_or(elements_finished,
                                              next_finished)

      return (next_time, elements_finished, next_input,
              state_ta, emit_ta, next_state, loop_state)

    returned = control_flow_ops.while_loop(
        condition, body,
        loop_vars=[time, elements_finished, next_input, state_ta, emit_ta,
                   state, loop_state],
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    (state_ta, emit_ta, final_state, final_loop_state) = returned[-4:]

    flat_states = nest.flatten(state_ta)
    flat_states = [array_ops.transpose(ta.stack(), (1, 0, 2))
                   for ta in flat_states]
    states = nest.pack_sequence_as(structure=state_ta,
                                   flat_sequence=flat_states)

    flat_outputs = nest.flatten(emit_ta)
    flat_outputs = [array_ops.transpose(ta.stack(), (1, 0, 2))
                    for ta in flat_outputs]
    outputs = nest.pack_sequence_as(structure=emit_ta,
                                    flat_sequence=flat_outputs)

    return (states, outputs, final_state)
def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
  """Internal while_loop body.

  Args:
    time: scalar int32 tensor.
    outputs_ta: structure of TensorArray.
    state: (structure of) state tensors and TensorArrays.
    inputs: (structure of) input tensors.
    finished: bool tensor (keeping track of what's finished).
    sequence_lengths: int32 tensor (keeping track of time of finish).

  Returns:
    `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
      next_sequence_lengths)`.
  """
  (next_outputs, decoder_state, next_inputs,
   decoder_finished) = decoder.step(time, inputs, state)
  if decoder.tracks_own_finished:
    next_finished = decoder_finished
  else:
    next_finished = math_ops.logical_or(decoder_finished, finished)
  next_sequence_lengths = array_ops.where(
      math_ops.logical_not(finished),
      array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
      sequence_lengths)

  nest.assert_same_structure(state, decoder_state)
  nest.assert_same_structure(outputs_ta, next_outputs)
  nest.assert_same_structure(inputs, next_inputs)

  # Zero out output values past finish
  if impute_finished:
    emit = nest.map_structure(
        lambda out, zero: array_ops.where(finished, zero, out),
        next_outputs,
        zero_outputs)
  else:
    emit = next_outputs

  # Copy through states past finish
  def _maybe_copy_state(new, cur):
    # TensorArrays and scalar states get passed through.
    if isinstance(cur, tensor_array_ops.TensorArray):
      pass_through = True
    else:
      new.set_shape(cur.shape)
      pass_through = new.shape.ndims == 0
    return new if pass_through else array_ops.where(finished, cur, new)

  if impute_finished:
    next_state = nest.map_structure(_maybe_copy_state, decoder_state, state)
  else:
    next_state = decoder_state

  outputs_ta = nest.map_structure(
      lambda ta, out: ta.write(time, out), outputs_ta, emit)
  return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
          next_sequence_lengths)
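# The `impute_finished` masking above is just a batched row select. A toy
# eager example (hypothetical values, not from the original source) makes
# the semantics concrete; note the library code relies on the legacy
# array_ops.where behavior of selecting whole rows from a vector condition,
# which the explicit [:, None] broadcast reproduces here in TF2 style.
import tensorflow as tf

finished = tf.constant([True, False, False])
next_out = tf.constant([[1., 1.], [2., 2.], [3., 3.]])
zeros = tf.zeros_like(next_out)

# Rows whose sequences have finished emit zeros; the rest pass through.
emit = tf.where(finished[:, None], zeros, next_out)
# emit == [[0., 0.], [2., 2.], [3., 3.]]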
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True,
               back_prop=True):
  """Like tf.while_loop, except emits a single While op."""
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
      expand_composites=True)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants,
                               expand_composites=False)
    signature = nest.map_structure(
        control_flow_ops._shape_invariant_to_type_spec, loop_vars,
        list(shape_invariants), expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        list(shape_invariants), expand_composites=False)
  else:
    signature = nest.map_structure(
        type_spec.type_spec_from_value, loop_vars, expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        expand_composites=False)

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")
    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
        maximum_iterations)
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations_loop_var.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars

    shape_invariants = [tensor_shape.TensorShape([])] * 2 + shape_invariants
    signature = (
        [tensor_spec.TensorSpec.from_tensor(loop_counter),
         tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] +
        signature)

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
      """Extra `cond` wrapper that can handle the extra counter loop_var."""
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently
      # there is no nest.zip so we call `_pack_sequence_as` which flattens
      # both `orig_loop_vars` and `args`, converts flows in `args` to
      # TensorArrays and packs it into the structure of `orig_loop_vars`.
      pred = cond(*_pack_sequence_as(orig_loop_vars, args))
      if (tensor_util.is_tensor(pred) and
          (pred.shape.dims is None or pred.shape.dims)):
        pred = array_ops.squeeze_v2(pred)

      if maximum_iterations is None:
        return pred
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations_arg, pred)

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        maximum_iterations_arg: Maximum iterations of the loop.
        *args: List of args

      Returns:
        A list of tensors the same length as args.
      """
      # Capture the tensors already captured in cond_graph so that they
      # appear in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently
      # there is no nest.zip so we call `_pack_sequence_as` which flattens
      # both `orig_loop_vars` and `args`, converts flows in `args` to
      # TensorArrays and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence_or_composite(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars),
                                 expand_composites=True)

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph`
    # so that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      cond_graph_captures = object_identity.ObjectIdentitySet(
          cond_graph.external_captures)
      for body_capture in body_graph.external_captures[num_cond_captures:]:
        assert body_capture not in cond_graph_captures
        cond_graph.capture(body_capture)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are
    # not specified.
    num_flattened_outputs = len(
        nest.flatten(orig_loop_vars, expand_composites=True))
    # First var is loop counter and second var is maximum_iterations.
    first_loop_var_index = 2
    _check_shapes_compat(
        body_graph.outputs[first_loop_var_index:first_loop_var_index +
                           num_flattened_outputs],
        nest.flatten(
            shape_invariants[first_loop_var_index:first_loop_var_index +
                             len_orig_loop_vars], expand_composites=True),
        nest.flatten(
            loop_vars[first_loop_var_index:first_loop_var_index +
                      len_orig_loop_vars], expand_composites=True))

    num_original_outputs = len(body_graph.outputs)
    if back_prop and util.output_all_intermediates():
      # Export all tensors in the loop body that may be needed for gradient
      # computation. We do this by accumulating the intermediate values in
      # TensorLists.
      intermediate_tensors = _get_intermediates(body_graph)

      for intermediate_tensor in intermediate_tensors:
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=intermediate_tensor.dtype,
            element_shape=intermediate_tensor.shape,
            max_num_elements=maximum_iterations)
        loop_vars.append(tensor_list)
        with cond_graph.as_default():
          # Add a placeholder to cond_graph's inputs corresponding to the
          # tensor_list.
          cond_graph.capture(tensor_list)
        with body_graph.as_default():
          # Push the intermediate tensor to the tensor list. This captures
          # the `tensor_list` as well.
          appended_tensor_list = list_ops.tensor_list_push_back(
              tensor_list, intermediate_tensor)
          # Add this modified tensor list to the list of outputs.
          body_graph.outputs.append(appended_tensor_list)

    flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))
    _check_inputs_outputs_types_match(body_graph, flattened_loop_vars)

    with ops.control_dependencies(
        list(cond_graph.control_captures) +
        list(body_graph.control_captures)):
      output_shapes = [t.shape for t in body_graph.outputs]
      orig_loop_vars_range = slice(
          first_loop_var_index,
          first_loop_var_index + num_flattened_outputs)
      output_shapes[orig_loop_vars_range] = nest.flatten(
          shape_invariants, expand_composites=True)[orig_loop_vars_range]

      cond_stateful_ops = [
          op for op in cond_graph.get_operations() if op._is_stateful
      ]
      body_stateful_ops = [
          op for op in body_graph.get_operations() if op._is_stateful
      ]
      if (cond_stateful_ops or body_stateful_ops):
        op_fn = gen_functional_ops._while
      else:
        op_fn = gen_functional_ops.stateless_while

      outputs = op_fn(
          flattened_loop_vars,
          util.create_new_tf_function(cond_graph),
          util.create_new_tf_function(body_graph),
          output_shapes=output_shapes,
          parallel_iterations=parallel_iterations,
          name=scope)
      # This is needed so we do not compute derivative wrt these extra
      # outputs.
      outputs[0].op._set_attr(
          "_num_original_outputs",
          attr_value_pb2.AttrValue(i=num_original_outputs))

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op)

    # Return identities for each output of the While op, rather than the
    # output of the While op directly. This makes pruning work if the output
    # of while_loop() is fetched: the lowering pass converts the While
    # outputs into IdentityN outputs, which if fetched will cause all ops in
    # the body to be run (since it takes all exit ops as input). After
    # lowering, each output identity op will end up with only the appropriate
    # exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

    outputs = _pack_sequence_as(
        orig_loop_vars,
        outputs[first_loop_var_index:first_loop_var_index +
                num_flattened_outputs])

    if return_same_structure:
      return outputs

    flattened_outputs = nest.flatten(outputs, expand_composites=True)
    if len(flattened_outputs) == 1:
      return flattened_outputs[0]
    else:
      return outputs
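# A trivial counter loop exercises the calling convention above. This is a
# sketch, not part of the original source: it assumes the function is in
# scope as `while_loop` and runs inside an explicit graph context, since the
# emitted While op is a graph construct.
import tensorflow as tf

with tf.Graph().as_default():
  def cond(i, total):
    return i < 10

  def body(i, total):
    return i + 1, total + i

  # With return_same_structure=True (the default here), the loop vars come
  # back packed in the same structure they were passed in.
  i_final, total_final = while_loop(
      cond, body, [tf.constant(0), tf.constant(0)])
  with tf.compat.v1.Session() as sess:
    print(sess.run(total_final))  # 45 == sum(range(10))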
def rnn_step(time, sequence_length, min_sequence_length, max_sequence_length,
             zero_output, state, call_cell, state_size,
             skip_conditionals=False):
  """Calculate one step of a dynamic RNN minibatch.

  Returns an (output, state) pair conditioned on the sequence_lengths.
  When skip_conditionals=False, the pseudocode is something like:

  if t >= max_sequence_length:
    return (zero_output, state)
  if t < min_sequence_length:
    return call_cell()

  # Selectively output zeros or output, old state or new state depending
  # on if we've finished calculating each row.
  new_output, new_state = call_cell()
  final_output = np.vstack([
    zero_output if time >= sequence_lengths[r] else new_output_r
    for r, new_output_r in enumerate(new_output)
  ])
  final_state = np.vstack([
    state[r] if time >= sequence_lengths[r] else new_state_r
    for r, new_state_r in enumerate(new_state)
  ])
  return (final_output, final_state)

  Args:
    time: Python int, the current time step
    sequence_length: int32 `Tensor` vector of size [batch_size]
    min_sequence_length: int32 `Tensor` scalar, min of sequence_length
    max_sequence_length: int32 `Tensor` scalar, max of sequence_length
    zero_output: `Tensor` vector of shape [output_size]
    state: Either a single `Tensor` matrix of shape
      `[batch_size, state_size]`, or a list/tuple of such tensors.
    call_cell: lambda returning tuple of (new_output, new_state) where
      new_output is a `Tensor` matrix of shape `[batch_size, output_size]`.
      new_state is a `Tensor` matrix of shape `[batch_size, state_size]`.
    state_size: The `cell.state_size` associated with the state.
    skip_conditionals: Python bool, whether to skip using the conditional
      calculations. This is useful for `dynamic_rnn`, where the input tensor
      matches `max_sequence_length`, and using conditionals just slows
      everything down.

  Returns:
    A tuple of (`final_output`, `final_state`) as given by the pseudocode
    above:
      final_output is a `Tensor` matrix of shape [batch_size, output_size]
      final_state is either a single `Tensor` matrix, or a tuple of such
        matrices (matching length and shapes of input `state`).

  Raises:
    ValueError: If the cell returns a state tuple whose length does not match
      that returned by `state_size`.
  """
  # Convert state to a list for ease of use
  flat_state = nest.flatten(state)
  flat_zero_output = nest.flatten(zero_output)

  def _copy_one_through(output, new_output):
    # If the state contains a scalar value we simply pass it through.
    if output.shape.ndims == 0:
      return new_output
    copy_cond = (time >= sequence_length)
    with ops.colocate_with(new_output):
      return array_ops.where(copy_cond, output, new_output)

  def _copy_some_through(flat_new_output, flat_new_state):
    # Use broadcasting select to determine which values should get
    # the previous state & zero output, and which values should get
    # a calculated state & output.
    flat_new_output = [
        _copy_one_through(zero_output, new_output)
        for zero_output, new_output in zip(flat_zero_output, flat_new_output)
    ]
    flat_new_state = [
        _copy_one_through(state, new_state)
        for state, new_state in zip(flat_state, flat_new_state)
    ]
    return flat_new_output + flat_new_state

  def _maybe_copy_some_through():
    """Run RNN step.  Pass through either no or some past state."""
    new_output, new_state = call_cell()

    nest.assert_same_structure(state, new_state)

    flat_new_state = nest.flatten(new_state)
    flat_new_output = nest.flatten(new_output)
    return control_flow_ops.cond(
        # if t < min_seq_len: calculate and return everything
        time < min_sequence_length,
        lambda: flat_new_output + flat_new_state,
        # else copy some of it through
        lambda: _copy_some_through(flat_new_output, flat_new_state))

  # TODO(ebrevdo): skipping these conditionals may cause a slowdown,
  # but benefits from removing cond() and its gradient. We should
  # profile with and without this switch here.
  if skip_conditionals:
    # Instead of using conditionals, perform the selective copy at all time
    # steps. This is faster when max_seq_len is equal to the number of
    # unrolls (which is typical for dynamic_rnn).
    new_output, new_state = call_cell()
    nest.assert_same_structure(state, new_state)
    new_state = nest.flatten(new_state)
    new_output = nest.flatten(new_output)
    final_output_and_state = _copy_some_through(new_output, new_state)
  else:
    empty_update = lambda: flat_zero_output + flat_state
    final_output_and_state = control_flow_ops.cond(
        # if t >= max_seq_len: copy all state through, output zeros
        time >= max_sequence_length, empty_update,
        # otherwise calculation is required: copy some or all of it through
        _maybe_copy_some_through)

  if len(final_output_and_state) != len(flat_zero_output) + len(flat_state):
    raise ValueError("Internal error: state and output were not concatenated "
                     "correctly.")
  final_output = final_output_and_state[:len(flat_zero_output)]
  final_state = final_output_and_state[len(flat_zero_output):]

  for output, flat_output in zip(final_output, flat_zero_output):
    output.set_shape(flat_output.get_shape())
  for substate, flat_substate in zip(final_state, flat_state):
    substate.set_shape(flat_substate.get_shape())

  final_output = nest.pack_sequence_as(
      structure=zero_output, flat_sequence=final_output)
  final_state = nest.pack_sequence_as(
      structure=state, flat_sequence=final_state)

  return final_output, final_state
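# The `call_cell` indirection above lets the caller close over the current
# input and state once, so the same thunk can be invoked from either
# conditional branch. A hypothetical single-step wiring follows; every name
# here (`cell`, `inputs_t`, `state`, `sequence_length`, `batch_size`, `time`)
# is an assumption for illustration, not the original call site.
call_cell = lambda: cell(inputs_t, state)  # inputs_t: [batch, input_depth]

output, state = rnn_step(
    time=time,  # scalar int32 tensor from the enclosing while_loop
    sequence_length=sequence_length,
    min_sequence_length=math_ops.reduce_min(sequence_length),
    max_sequence_length=math_ops.reduce_max(sequence_length),
    zero_output=array_ops.zeros([batch_size, cell.output_size]),
    state=state,
    call_cell=call_cell,
    state_size=cell.state_size,
    skip_conditionals=True)  # typical for dynamic_rnn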
def _assert_structures_equal(self, struct1, struct2):
  """Asserts matching nesting and element-wise equal leaf values."""
  tf_nest.assert_same_structure(struct1, struct2)
  for a, b in zip(tf_nest.flatten(struct1), tf_nest.flatten(struct2)):
    np.testing.assert_array_equal(a, b)
def testAssertSameStructure(self):
  structure1 = (((1, 2), 3), 4, (5, 6))
  structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
  structure_different_num_elements = ("spam", "eggs")
  structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
  nest.assert_same_structure(structure1, structure2)
  nest.assert_same_structure("abc", 1.0)
  nest.assert_same_structure("abc", np.array([0, 1]))
  nest.assert_same_structure("abc", constant_op.constant([0, 1]))

  with self.assertRaisesRegexp(
      ValueError, "don't have the same number of elements"):
    nest.assert_same_structure(structure1, structure_different_num_elements)

  with self.assertRaisesRegexp(
      ValueError, "don't have the same number of elements"):
    nest.assert_same_structure([0, 1], np.array([0, 1]))

  with self.assertRaisesRegexp(
      ValueError, "don't have the same number of elements"):
    nest.assert_same_structure(0, [0, 1])

  self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1])

  with self.assertRaisesRegexp(
      ValueError, "don't have the same nested structure"):
    nest.assert_same_structure(structure1, structure_different_nesting)

  named_type_0 = collections.namedtuple("named_0", ("a", "b"))
  named_type_1 = collections.namedtuple("named_1", ("a", "b"))
  self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
                    named_type_0("a", "b"))

  nest.assert_same_structure(named_type_0(3, 4), named_type_0("a", "b"))

  self.assertRaises(TypeError, nest.assert_same_structure,
                    named_type_0(3, 4), named_type_1(3, 4))

  with self.assertRaisesRegexp(
      ValueError, "don't have the same nested structure"):
    nest.assert_same_structure(named_type_0(3, 4), named_type_0([3], 4))

  with self.assertRaisesRegexp(
      ValueError, "don't have the same nested structure"):
    nest.assert_same_structure([[3], 4], [3, [4]])
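# Outside the test harness, the same checks reduce to a few lines. A minimal
# standalone sketch, assuming the public tf.nest.assert_same_structure API
# behaves like the `nest` module exercised above:
import tensorflow as tf

# Same nesting with different leaf values: passes silently.
tf.nest.assert_same_structure(((1, 2), 3), (("a", "b"), "c"))

# Different nesting: raises ValueError.
try:
  tf.nest.assert_same_structure(((1, 2), 3), (1, (2, 3)))
except ValueError as e:
  print(type(e).__name__)

# Tuple vs. list at the same position: raises TypeError, because sequence
# types are checked by default (check_types=True).
try:
  tf.nest.assert_same_structure((1, 2), [1, 2])
except TypeError as e:
  print(type(e).__name__)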
def _testWithMaybeMultiAttention(self,
                                 is_multi,
                                 create_attention_mechanisms,
                                 expected_final_output,
                                 expected_final_state,
                                 attention_mechanism_depths,
                                 alignment_history=False,
                                 expected_final_alignment_history=None,
                                 attention_layer_sizes=None,
                                 attention_layers=None,
                                 create_query_layer=False,
                                 create_memory_layer=True,
                                 create_attention_kwargs=None):
  # Allow is_multi to be True with a single mechanism to enable test for
  # passing in a single mechanism in a list.
  assert len(create_attention_mechanisms) == 1 or is_multi
  encoder_sequence_length = [3, 2, 3, 1, 1]
  decoder_sequence_length = [2, 0, 1, 2, 3]
  batch_size = 5
  encoder_max_time = 8
  decoder_max_time = 4
  input_depth = 7
  encoder_output_depth = 10
  cell_depth = 9
  create_attention_kwargs = create_attention_kwargs or {}

  if attention_layer_sizes is not None:
    # Compute sum of attention_layer_sizes. Use encoder_output_depth if None.
    attention_depth = sum(
        attention_layer_size or encoder_output_depth
        for attention_layer_size in attention_layer_sizes)
  elif attention_layers is not None:
    # Compute sum of attention_layers output depth.
    attention_depth = sum(
        attention_layer.compute_output_shape(
            [batch_size, cell_depth + encoder_output_depth])
        .dims[-1].value for attention_layer in attention_layers)
  else:
    attention_depth = encoder_output_depth * len(create_attention_mechanisms)

  decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                   input_depth).astype(np.float32)
  encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                    encoder_output_depth).astype(np.float32)

  attention_mechanisms = []
  for creator, depth in zip(create_attention_mechanisms,
                            attention_mechanism_depths):
    # Create a memory layer with deterministic initializer to avoid
    # randomness in the test between graph and eager.
    if create_query_layer:
      create_attention_kwargs["query_layer"] = keras.layers.Dense(
          depth, kernel_initializer="ones", use_bias=False)
    if create_memory_layer:
      create_attention_kwargs["memory_layer"] = keras.layers.Dense(
          depth, kernel_initializer="ones", use_bias=False)

    attention_mechanisms.append(
        creator(units=depth,
                memory=encoder_outputs,
                memory_sequence_length=encoder_sequence_length,
                **create_attention_kwargs))

  with self.cached_session(use_gpu=True):
    attention_layer_size = attention_layer_sizes
    attention_layer = attention_layers
    if not is_multi:
      if attention_layer_size is not None:
        attention_layer_size = attention_layer_size[0]
      if attention_layer is not None:
        attention_layer = attention_layer[0]
    cell = keras.layers.LSTMCell(cell_depth,
                                 recurrent_activation="sigmoid",
                                 kernel_initializer="ones",
                                 recurrent_initializer="ones")
    cell = wrapper.AttentionWrapper(
        cell,
        attention_mechanisms if is_multi else attention_mechanisms[0],
        attention_layer_size=attention_layer_size,
        alignment_history=alignment_history,
        attention_layer=attention_layer)
    if cell._attention_layers is not None:
      for layer in cell._attention_layers:
        if getattr(layer, "kernel_initializer") is None:
          layer.kernel_initializer = initializers.glorot_uniform(seed=1337)

    sampler = sampler_py.TrainingSampler()
    my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler)
    initial_state = cell.get_initial_state(dtype=dtypes.float32,
                                           batch_size=batch_size)
    final_outputs, final_state, _ = my_decoder(
        decoder_inputs,
        initial_state=initial_state,
        sequence_length=decoder_sequence_length)

    self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
    self.assertIsInstance(final_state, wrapper.AttentionWrapperState)

    expected_time = (expected_final_state.time
                     if context.executing_eagerly() else None)
    self.assertEqual(
        (batch_size, expected_time, attention_depth),
        tuple(final_outputs.rnn_output.get_shape().as_list()))
    self.assertEqual(
        (batch_size, expected_time),
        tuple(final_outputs.sample_id.get_shape().as_list()))

    self.assertEqual(
        (batch_size, attention_depth),
        tuple(final_state.attention.get_shape().as_list()))
    self.assertEqual(
        (batch_size, cell_depth),
        tuple(final_state.cell_state[0].get_shape().as_list()))
    self.assertEqual(
        (batch_size, cell_depth),
        tuple(final_state.cell_state[1].get_shape().as_list()))

    if alignment_history:
      if is_multi:
        state_alignment_history = []
        for history_array in final_state.alignment_history:
          history = history_array.stack()
          self.assertEqual(
              (expected_time, batch_size, encoder_max_time),
              tuple(history.get_shape().as_list()))
          state_alignment_history.append(history)
        state_alignment_history = tuple(state_alignment_history)
      else:
        state_alignment_history = final_state.alignment_history.stack()
        self.assertEqual(
            (expected_time, batch_size, encoder_max_time),
            tuple(state_alignment_history.get_shape().as_list()))
      nest.assert_same_structure(
          cell.state_size,
          cell.zero_state(batch_size, dtypes.float32))
      # Remove the history from final_state for purposes of the
      # remainder of the tests.
      final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
    else:
      state_alignment_history = ()

    self.evaluate(variables.global_variables_initializer())
    eval_result = self.evaluate({
        "final_outputs": final_outputs,
        "final_state": final_state,
        "state_alignment_history": state_alignment_history,
    })

    final_output_info = nest.map_structure(get_result_summary,
                                           eval_result["final_outputs"])
    final_state_info = nest.map_structure(get_result_summary,
                                          eval_result["final_state"])
    print("final_output_info: ", final_output_info)
    print("final_state_info: ", final_state_info)

    nest.map_structure(self.assertAllCloseOrEqual, expected_final_output,
                       final_output_info)
    nest.map_structure(self.assertAllCloseOrEqual, expected_final_state,
                       final_state_info)
    if alignment_history:  # by default, the wrapper emits attention as output
      final_alignment_history_info = nest.map_structure(
          get_result_summary, eval_result["state_alignment_history"])
      print("final_alignment_history_info: ", final_alignment_history_info)
      nest.map_structure(
          self.assertAllCloseOrEqual,
          # outputs are batch major but the stacked TensorArray is time major
          expected_final_alignment_history,
          final_alignment_history_info)
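# Stripped of its parametrization, the stack the test builds is short. A
# hypothetical sketch using the same module aliases; the BahdanauAttentionV2
# creator name is an assumption inferred from the `units=` keyword signature
# the test passes to its `creator`, and the fixed sizes replace parameters.
mechanism = wrapper.BahdanauAttentionV2(
    units=10, memory=encoder_outputs,
    memory_sequence_length=encoder_sequence_length)
cell = keras.layers.LSTMCell(9)
attn_cell = wrapper.AttentionWrapper(
    cell, mechanism, attention_layer_size=10, alignment_history=True)
sampler = sampler_py.TrainingSampler()
my_decoder = basic_decoder.BasicDecoderV2(cell=attn_cell, sampler=sampler)
# Calling my_decoder(decoder_inputs, initial_state=..., sequence_length=...)
# then yields a (BasicDecoderOutput, AttentionWrapperState, lengths) triple.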
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               maximum_iterations=None,
               name=None,
               return_same_structure=True):
  """Like tf.while_loop, except emits a single While op."""
  maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants)
  else:
    shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")

    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter] + loop_vars

    shape_invariants = type(shape_invariants)(
        [tensor_shape.scalar()]) + shape_invariants

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(loop_counter, *args):
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently
      # there is no nest.zip so we call `_pack_sequence_as` which flattens
      # both `orig_loop_vars` and `args`, converts flows in `args` to
      # TensorArrays and packs it into the structure of `orig_loop_vars`.
      if maximum_iterations is None:
        return cond(*_pack_sequence_as(orig_loop_vars, args))
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations,
            cond(*_pack_sequence_as(orig_loop_vars, args)))

    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        loop_vars, {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileCondFuncGraph(cond_name),
        add_control_dependencies=add_control_dependencies)

    # Add external_captures of cond to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + cond_graph.external_captures
    shape_invariants = shape_invariants + type(shape_invariants)(
        [t.shape for t in cond_graph.external_captures])

    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:len_orig_loop_vars] - Args for the original loop body.
          args[len_orig_loop_vars:] - External captures of cond. These get
            passed through as is.

      Returns:
        A list of tensors the same length as args.
      """
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently
      # there is no nest.zip so we call `_pack_sequence_as` which flattens
      # both `orig_loop_vars` and `args`, converts flows in `args` to
      # TensorArrays and packs it into the structure of `orig_loop_vars`.
      outputs = body(
          *_pack_sequence_as(orig_loop_vars, args[:len_orig_loop_vars]))
      if not nest.is_sequence(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars))

      outputs = _tensor_array_to_flow(outputs)

      # Return the external_captures of cond_graph as is, i.e., treat them as
      # loop invariants.
      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs) + list(
          args[len_orig_loop_vars:])

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        loop_vars, {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileBodyFuncGraph(body_name),
        add_control_dependencies=add_control_dependencies)

    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture `external_captures` of `body_graph` in `cond_graph` so that it
    # expects to receive those as arguments.
    # TODO(b/118457764): Dedup tensors that are captured in both the cond and
    # body. This logic already exists in cond_v2.
    with cond_graph.as_default():
      for external_capture in body_graph.external_captures:
        assert external_capture not in cond_graph.captures, (
            "Looks like both cond and body are capturing the same tensor %s. "
            "This is not supported yet. For now consider passing this as a "
            "loop variable." % str(external_capture))
        cond_graph.capture(external_capture)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are
    # not specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars))
    _check_shapes_compat(
        body_graph.outputs[1:1 + num_flattened_outputs],
        nest.flatten(shape_invariants[1:1 + len_orig_loop_vars]),
        nest.flatten(loop_vars[1:1 + len_orig_loop_vars]))
    flattened_loop_vars = nest.flatten(loop_vars)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))

    outputs = gen_functional_ops._while(
        flattened_loop_vars,
        util.create_new_tf_function(cond_graph),
        util.create_new_tf_function(body_graph),
        output_shapes=[t.shape for t in body_graph.outputs],
        name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

    # Return identities for each output of the While op, rather than the
    # output of the While op directly. This makes pruning work if the output
    # of while_loop() is fetched: the lowering pass converts the While
    # outputs into IdentityN outputs, which if fetched will cause all ops in
    # the body to be run (since it takes all exit ops as input). After
    # lowering, each output identity op will end up with only the appropriate
    # exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

    # First var is loop counter.
    outputs = _pack_sequence_as(orig_loop_vars,
                                outputs[1:1 + num_flattened_outputs])

    if return_same_structure:
      return outputs

    flattened_outputs = nest.flatten(outputs)
    if len(flattened_outputs) == 1:
      return flattened_outputs[0]
    else:
      return outputs