def _call_cell(self, inputs, initial_cell_state=None, initial_output=None,
               dtype=None, sequence_length=None):
    """Apply the fused block LSTM op to a whole time-major input sequence.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`.
      initial_cell_state: initial value for cell state, shape
        `[batch_size, self._num_units]`.
      initial_output: initial value of cell output, shape
        `[batch_size, self._num_units]`.
      dtype: The data type for the initial state and expected output. In
        this method it is used for the all-zero peephole weights that are
        built when peepholes are disabled.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) size `[batch_size]`, values in
        `[0, time_len)` or None.

    Returns:
      A pair containing:
      - Cell state (cs): A `3-D` tensor of shape
        `[time_len, batch_size, output_size]`
      - Output (h): A `3-D` tensor of shape
        `[time_len, batch_size, output_size]`
    """
    static_shape = inputs.get_shape().with_rank(3)
    num_steps = static_shape[0].value
    if num_steps is None:
        # The static time dimension is unknown; read it from the runtime shape.
        num_steps = array_ops.shape(inputs)[0]

    if not self._use_peephole:
        # Without peepholes the op still takes diagonal weights; pass zeros.
        peep_i = peep_f = peep_o = array_ops.zeros(
            [self._num_units], dtype=dtype)
    else:
        peep_i = self._w_i_diag
        peep_o = self._w_o_diag
        peep_f = self._w_f_diag

    # Upper bound on processed steps: the full time axis unless explicit
    # per-sequence lengths are supplied.
    longest = (num_steps if sequence_length is None
               else math_ops.reduce_max(sequence_length))
    seq_len_max = math_ops.to_int64(longest)

    (_, cell_states, _, _, _, _, outputs) = gen_lstm_ops.block_lstm(
        seq_len_max=seq_len_max,
        x=inputs,
        cs_prev=initial_cell_state,
        h_prev=initial_output,
        w=self._kernel,
        wci=peep_i,
        wcf=peep_f,
        wco=peep_o,
        b=self._bias,
        forget_bias=self._forget_bias,
        cell_clip=self._cell_clip,
        use_peephole=self._use_peephole)
    return cell_states, outputs
def _call_cell(self, inputs, initial_cell_state=None, initial_output=None,
               dtype=None, sequence_length=None):
    """Run the block LSTM op over time-major inputs from the given state.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`.
      initial_cell_state: initial value for cell state, shape
        `[batch_size, self._num_units]`.
      initial_output: initial value of cell output, shape
        `[batch_size, self._num_units]`.
      dtype: The data type for the initial state and expected output. Here
        it determines the dtype of the zero peephole weights created when
        peepholes are turned off.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) size `[batch_size]`, values in
        `[0, time_len)` or None.

    Returns:
      A pair containing:
      - Cell state (cs): A `3-D` tensor of shape
        `[time_len, batch_size, output_size]`
      - Output (h): A `3-D` tensor of shape
        `[time_len, batch_size, output_size]`
    """
    rank3_shape = inputs.get_shape().with_rank(3)
    steps = rank3_shape.dims[0].value
    if steps is None:
        # Time dimension not known statically; take it from the dynamic shape.
        steps = array_ops.shape(inputs)[0]

    if self._use_peephole:
        diag_i = self._w_i_diag
        diag_o = self._w_o_diag
        diag_f = self._w_f_diag
    else:
        # The op requires peephole tensors even when unused; supply zeros.
        zero_diag = array_ops.zeros([self._num_units], dtype=dtype)
        diag_i = diag_f = diag_o = zero_diag

    if sequence_length is not None:
        step_limit = math_ops.reduce_max(sequence_length)
    else:
        step_limit = steps
    step_limit = math_ops.cast(step_limit, dtypes.int64)

    op_outputs = gen_lstm_ops.block_lstm(
        seq_len_max=step_limit,
        x=inputs,
        cs_prev=initial_cell_state,
        h_prev=initial_output,
        w=self._kernel,
        wci=diag_i,
        wcf=diag_f,
        wco=diag_o,
        b=self._bias,
        forget_bias=self._forget_bias,
        cell_clip=self._cell_clip,
        use_peephole=self._use_peephole)
    # The op yields seven tensors; only cell state and output are returned.
    _, cs, _, _, _, _, h = op_outputs
    return cs, h
def _call_cell(self, inputs, initial_cell_state, initial_output, dtype,
               sequence_length):
    """Create the LSTM variables and run the block LSTM op on the inputs.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`.
      initial_cell_state: initial value for cell state, shape
        `[batch_size, self._num_units]`.
      initial_output: initial value of cell output, shape
        `[batch_size, self._num_units]`.
      dtype: The data type for the initial state and expected output. It is
        also the dtype of every variable (and of the zero peephole weights)
        created here.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) size `[batch_size]`, values in
        `[0, time_len)` or None.

    Returns:
      A pair containing:
      - Cell state (cs): A `3-D` tensor of shape
        `[time_len, batch_size, output_size]`
      - Output (h): A `3-D` tensor of shape
        `[time_len, batch_size, output_size]`
    """
    static_shape = inputs.get_shape().with_rank(3)
    num_steps = static_shape[0].value
    if num_steps is None:
        # Unknown static time dimension: fall back to the runtime shape.
        num_steps = array_ops.shape(inputs)[0]
    feature_dim = static_shape[2].value

    # One kernel over the concatenated [input, output] features; its second
    # dimension is four times the number of units.
    kernel = vs.get_variable(
        "kernel",
        [feature_dim + self._num_units, self._num_units * 4],
        dtype=dtype)
    bias_dim = kernel.get_shape().with_rank(2)[1]
    bias = vs.get_variable(
        "bias",
        [bias_dim],
        initializer=init_ops.constant_initializer(0.0),
        dtype=dtype)

    if self._use_peephole:
        # Created in the order i, o, f.
        peep_i, peep_o, peep_f = [
            vs.get_variable(var_name, [self._num_units], dtype=dtype)
            for var_name in ("w_i_diag", "w_o_diag", "w_f_diag")
        ]
    else:
        peep_i = peep_o = peep_f = array_ops.zeros(
            [self._num_units], dtype=dtype)

    longest = (num_steps if sequence_length is None
               else math_ops.reduce_max(sequence_length))
    seq_len_max = math_ops.to_int64(longest)

    (_, cell_states, _, _, _, _, outputs) = gen_lstm_ops.block_lstm(
        seq_len_max=seq_len_max,
        x=inputs,
        cs_prev=initial_cell_state,
        h_prev=initial_output,
        w=kernel,
        wci=peep_i,
        wco=peep_o,
        wcf=peep_f,
        b=bias,
        forget_bias=self._forget_bias,
        cell_clip=self._cell_clip,
        use_peephole=self._use_peephole)
    return cell_states, outputs
def _block_lstm(seq_len_max, x, w, b, cs_prev=None, h_prev=None, wci=None,
                wcf=None, wco=None, forget_bias=None, cell_clip=None,
                use_peephole=None, name=None):
    r"""Run the block LSTM op over a list of per-step input tensors.

    The tensors in `x` are stacked into a single tensor, passed to
    `gen_lstm_ops.block_lstm`, and each of the op's seven outputs is
    unstacked back into a list.

    Args:
      seq_len_max: A `Tensor` of type `int64`.
      x: A list of at least 1 `Tensor` objects of the same type in: `float32`.
        The first element must have rank 2; its first dimension is used as
        the batch size for default states.
      w: A `Tensor`. Must have the same type as `x`.
      b: A `Tensor`. Must have the same type as `x`. Must have rank 1 with a
        statically known length; the cell size is taken as a quarter of it.
      cs_prev: A `Tensor`. Must have the same type as `x`. When None, a
        `float32` zero tensor of shape `[batch_size, cell_size]` is used.
      h_prev: A `Tensor`. Must have the same type as `x`. When None, a
        `float32` zero tensor of shape `[batch_size, cell_size]` is used.
      wci: A `Tensor`. Must have the same type as `x`. When None, `wci`,
        `wcf` and `wco` are all replaced by one `float32` zero tensor of
        shape `[cell_size]`.
      wcf: A `Tensor`. Must have the same type as `x`.
      wco: A `Tensor`. Must have the same type as `x`.
      forget_bias: An optional `float`. Defaults to `1`.
      cell_clip: An optional `float`. Defaults to `3`.
      use_peephole: An optional `bool`. Defaults to `False`.
      name: A name for the operation (optional).

    Returns:
      A tuple of `Tensor` objects (i, cs, f, o, ci, co, h). Each element is
      a list with the same number of `Tensor` objects as `x`, of the same
      type as `x`.

    Raises:
      ValueError: If `b` does not have a valid shape (its static length is
        unknown).
    """
    batch_size = x[0].get_shape().with_rank(2)[0].value
    cell_size4 = b.get_shape().with_rank(1)[0].value
    if cell_size4 is None:
        raise ValueError("`b` shape must not be None.")
    # Floor division keeps cell_size an int; true division would yield a
    # float, which is not a valid tensor dimension for the shapes below.
    cell_size = cell_size4 // 4

    if cs_prev is None or h_prev is None:
        # A single zero tensor serves as whichever initial state is missing.
        zero_state = array_ops.constant(
            0, dtype=dtypes.float32, shape=[batch_size, cell_size])
        if cs_prev is None:
            cs_prev = zero_state
        if h_prev is None:
            h_prev = zero_state

    if wci is None:
        # A missing `wci` resets all three peephole weights to zeros.
        wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size])
        wco = wci
        wcf = wci

    # pylint: disable=protected-access
    i, cs, f, o, ci, co, h = gen_lstm_ops.block_lstm(
        seq_len_max=seq_len_max,
        x=array_ops.stack(x),
        cs_prev=cs_prev,
        h_prev=h_prev,
        w=w,
        wci=wci,
        wco=wco,
        wcf=wcf,
        b=b,
        forget_bias=forget_bias,
        cell_clip=cell_clip,
        name=name,
        use_peephole=use_peephole)

    return (array_ops.unstack(i), array_ops.unstack(cs), array_ops.unstack(f),
            array_ops.unstack(o), array_ops.unstack(ci), array_ops.unstack(co),
            array_ops.unstack(h))
def _block_lstm(seq_len_max, x, w, b, cs_prev=None, h_prev=None, wci=None,
                wcf=None, wco=None, forget_bias=None, cell_clip=None,
                use_peephole=None, name=None):
    r"""Run the block LSTM op over a list of per-step input tensors.

    The tensors in `x` are stacked into a single tensor, passed to
    `gen_lstm_ops.block_lstm`, and each of the op's seven outputs is
    unstacked back into a list.

    Args:
      seq_len_max: A `Tensor` of type `int64`.
      x: A list of at least 1 `Tensor` objects of the same type in: `float32`.
        The first element must have rank 2; its first dimension is used as
        the batch size for default states.
      w: A `Tensor`. Must have the same type as `x`.
      b: A `Tensor`. Must have the same type as `x`. Must have rank 1 with a
        statically known length; the cell size is taken as a quarter of it.
      cs_prev: A `Tensor`. Must have the same type as `x`. When None, a
        `float32` zero tensor of shape `[batch_size, cell_size]` is used.
      h_prev: A `Tensor`. Must have the same type as `x`. When None, a
        `float32` zero tensor of shape `[batch_size, cell_size]` is used.
      wci: A `Tensor`. Must have the same type as `x`. When None, `wci`,
        `wcf` and `wco` are all replaced by one `float32` zero tensor of
        shape `[cell_size]`.
      wcf: A `Tensor`. Must have the same type as `x`.
      wco: A `Tensor`. Must have the same type as `x`.
      forget_bias: An optional `float`. Defaults to `1`.
      cell_clip: An optional `float`. Defaults to `-1` (no clipping).
      use_peephole: An optional `bool`. Defaults to `False`.
      name: A name for the operation (optional).

    Returns:
      A tuple of `Tensor` objects (i, cs, f, o, ci, co, h). Each element is
      a list with the same number of `Tensor` objects as `x`, of the same
      type as `x`.

    Raises:
      ValueError: If `b` does not have a valid shape (its static length is
        unknown).
    """
    batch_size = x[0].get_shape().with_rank(2)[0].value
    cell_size4 = b.get_shape().with_rank(1)[0].value
    if cell_size4 is None:
        raise ValueError("`b` shape must not be None.")
    # Floor division keeps cell_size an int; true division would yield a
    # float, which is not a valid tensor dimension for the shapes below.
    cell_size = cell_size4 // 4

    if cs_prev is None or h_prev is None:
        # A single zero tensor serves as whichever initial state is missing.
        zero_state = array_ops.constant(
            0, dtype=dtypes.float32, shape=[batch_size, cell_size])
        if cs_prev is None:
            cs_prev = zero_state
        if h_prev is None:
            h_prev = zero_state

    if wci is None:
        # A missing `wci` resets all three peephole weights to zeros.
        wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size])
        wcf = wci
        wco = wci

    # pylint: disable=protected-access
    i, cs, f, o, ci, co, h = gen_lstm_ops.block_lstm(
        seq_len_max=seq_len_max,
        x=array_ops.stack(x),
        cs_prev=cs_prev,
        h_prev=h_prev,
        w=w,
        wci=wci,
        wcf=wcf,
        wco=wco,
        b=b,
        forget_bias=forget_bias,
        # An unspecified clip is passed to the op as -1.
        cell_clip=cell_clip if cell_clip is not None else -1,
        name=name,
        use_peephole=use_peephole)

    return (array_ops.unstack(i), array_ops.unstack(cs), array_ops.unstack(f),
            array_ops.unstack(o), array_ops.unstack(ci), array_ops.unstack(co),
            array_ops.unstack(h))