def __init__(self, conv_ndims, input_shape, output_channels, kernel_shape,
             use_bias=True, skip_connection=False, forget_bias=1.0,
             initializers=None, reuse=None):
    """Construct ConvLSTMCell.

    Args:
      conv_ndims: Convolution dimensionality (1, 2 or 3).
      input_shape: Shape of the input as an int tuple or list, excluding the
        batch size. Must have `conv_ndims + 1` entries (spatial dims followed
        by the channel count).
      output_channels: int, number of output channels of the conv LSTM.
      kernel_shape: Shape of kernel as in tuple (of size 1,2 or 3).
      use_bias: (bool) Use bias in convolutions.
      skip_connection: If set to `True`, concatenate the input to the
        output of the conv LSTM. Default: `False`.
      forget_bias: Forget bias.
      initializers: Unused.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope; forwarded to the base layer as `_reuse`.

    Raises:
      ValueError: If `input_shape` is incompatible with `conv_ndims`.
    """
    super(ConvLSTMCell, self).__init__(_reuse=reuse)

    if conv_ndims != len(input_shape) - 1:
        raise ValueError("Invalid input_shape {} for conv_ndims={}.".format(
            input_shape, conv_ndims))

    self._conv_ndims = conv_ndims
    self._input_shape = input_shape
    self._output_channels = output_channels
    self._kernel_shape = kernel_shape
    self._use_bias = use_bias
    self._forget_bias = forget_bias
    self._skip_connection = skip_connection
    self._reuse = reuse

    # With a skip connection the input channels are concatenated onto the
    # LSTM output, widening the emitted tensor (but not the cell state).
    self._total_output_channels = output_channels
    if self._skip_connection:
        self._total_output_channels += self._input_shape[-1]

    # `input_shape` may be a tuple (as documented). `tuple + list` raises
    # TypeError, so take the spatial part as a list before appending channels.
    spatial_shape = list(self._input_shape[:-1])
    state_size = tensor_shape.TensorShape(spatial_shape +
                                          [self._output_channels])
    self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size)
    self._output_size = tensor_shape.TensorShape(
        spatial_shape + [self._total_output_channels])
def call(self, inputs, state): sigmoid = math_ops.sigmoid # Parameters of gates are concatenated into one multiply for efficiency. c, h = state feat_out = self._num_units inputs_shape = inputs.get_shape() feat_in = int(inputs_shape[1]) // self.n_nodes scope = vs.get_variable_scope() with vs.variable_scope(scope): Wx = tf.get_variable("input_weights", [feat_in, feat_out], dtype=tf.float32, initializer=self._kernel_initializer) Wh = tf.get_variable("hidden_weights", [feat_out, feat_out], dtype=tf.float32, initializer=self._kernel_initializer) W = tf.get_variable("mutual_weights", [2 * feat_out, feat_out * 3], dtype=tf.float32, initializer=self._kernel_initializer) bias = tf.get_variable("biases", [feat_out * 3], dtype=tf.float32, initializer=self._bias_initializer) conv_inputs = tf.nn.relu( gcn(inputs, self.conv_matrix, Wx, None, feat_in, feat_out, self.n_nodes)) conv_hidden = tf.nn.relu( gcn(h, self.conv_matrix, Wh, None, feat_out, feat_out, self.n_nodes)) concat = array_ops.concat([conv_inputs, conv_hidden], 1) value = gcn(concat, self.conv_matrix, W, bias, 2 * feat_out, 3 * feat_out, self.n_nodes) value = tf.reshape(value, [-1, self.n_nodes, feat_out * 3]) value = tf.nn.bias_add(value, bias) value = tf.reshape(value, [-1, self.n_nodes * feat_out * 3]) i, j, o = array_ops.split(value=value, num_or_size_splits=3, axis=1) # tied gate i = sigmoid(i) new_c = (1 - i) * c + i * self._activation(j) output_gate = sigmoid(o) new_h = self._activation(new_c) * output_gate new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) if self.use_residual_connection: if new_h.get_shape().as_list() != inputs.get_shape().as_list(): return new_h + self.projection_fn( inputs) * output_gate, new_state else: return new_h + inputs * output_gate, new_state else: return new_h, new_state
def _testDropoutWrapper(self, batch_size=None, time_steps=None,
                        parallel_iterations=None, **kwargs):
    """Run `DropoutWrapper(LSTMCell(3))` through `dynamic_rnn`; check shapes.

    If both `batch_size` and `time_steps` are None, a fixed input of 2 time
    steps, batch size 1 and depth 3 is used. Otherwise a random normal input
    of shape `(time_steps, batch_size, 3)` is generated. The initial `c` and
    `h` are the same constant tensor filled with 0.1. Extra keyword arguments
    are forwarded to `DropoutWrapper`.

    Returns:
      The evaluated `[outputs, final_state]` list.
    """
    with self.test_session() as sess:
        root_initializer = init_ops.constant_initializer(0.5)
        with variable_scope.variable_scope("root",
                                           initializer=root_initializer):
            use_fixed_input = batch_size is None and time_steps is None
            if use_fixed_input:
                # 2 time steps, batch size 1, depth 3
                batch_size, time_steps = 1, 2
                x = constant_op.constant([[[2., 2., 2.]], [[1., 1., 1.]]],
                                         dtype=dtypes.float32)
                state_rows = [[0.1, 0.1, 0.1]]
            else:
                random_input = np.random.randn(time_steps, batch_size, 3)
                x = constant_op.constant(random_input.astype(np.float32))
                state_rows = [[0.1, 0.1, 0.1]] * batch_size
            # A single constant is shared as both the c and h initial state.
            shared_init = constant_op.constant(state_rows,
                                               dtype=dtypes.float32)
            m = rnn_cell_impl.LSTMStateTuple(shared_init, shared_init)

            wrapped_cell = rnn_cell_impl.DropoutWrapper(
                rnn_cell_impl.LSTMCell(3), dtype=x.dtype, **kwargs)
            outputs, final_state = rnn.dynamic_rnn(
                cell=wrapped_cell,
                time_major=True,
                parallel_iterations=parallel_iterations,
                inputs=x,
                initial_state=m)

            sess.run([variables_lib.global_variables_initializer()])
            res = sess.run([outputs, final_state])
            output_value, state_value = res
            self.assertEqual(output_value.shape, (time_steps, batch_size, 3))
            self.assertEqual(state_value.c.shape, (batch_size, 3))
            self.assertEqual(state_value.h.shape, (batch_size, 3))
            return res
def call(self, inputs, state):
    """One LSTM step with optional per-gate layer norm and gate dropout.

    The signature differs from `call` in tensorflow/python/layers/base.py,
    which is why PEP 8 / lint inspections may flag it:
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/layers/base.py

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: Pair `(c, h)` of the previous memory cell and hidden unit. It is
        unpacked directly, so a two-element (tuple-like) state is required.

    Returns:
      `(new_h, LSTMStateTuple(new_c, new_h))`.

    Dropout: when `self._recurrent_dropout` is truthy, only the candidate
    `g = activation(j)` is dropped. Otherwise the pre-activations `i`, `f`,
    `o` and the candidate `g` all go through `nn_ops.dropout`.
    """
    c_prev, h_prev = state  # memory cell, hidden unit
    merged = array_ops.concat([inputs, h_prev], 1)
    projected = self._linear(
        merged, [merged.get_shape()[-1], 4 * self._num_units])
    gate_i, gate_j, gate_f, gate_o = array_ops.split(
        value=projected, num_or_size_splits=4, axis=1)

    if self._layer_norm:
        gate_i = self._layer_normalization(gate_i, "layer_norm_i")
        gate_j = self._layer_normalization(gate_j, "layer_norm_j")
        gate_f = self._layer_normalization(gate_f, "layer_norm_f")
        gate_o = self._layer_normalization(gate_o, "layer_norm_o")

    candidate = self._activation(gate_j)

    keep_prob, seed = self._keep_prob, self._seed
    if self._recurrent_dropout:
        candidate = nn_ops.dropout(candidate, keep_prob, seed=seed)
    else:
        gate_i = nn_ops.dropout(gate_i, keep_prob, seed=seed)
        candidate = nn_ops.dropout(candidate, keep_prob, seed=seed)
        gate_f = nn_ops.dropout(gate_f, keep_prob, seed=seed)
        gate_o = nn_ops.dropout(gate_o, keep_prob, seed=seed)

    # No layer normalization is applied to the memory cell itself.
    admitted = math_ops.sigmoid(gate_i) * candidate
    retained = c_prev * math_ops.sigmoid(gate_f + self._forget_bias)
    new_c = retained + admitted
    new_h = self._activation(new_c) * math_ops.sigmoid(gate_o)
    return new_h, rnn_cell_impl.LSTMStateTuple(new_c, new_h)
def __init__(self, conv_ndims, input_shape, num_channels_out, kernel_shape,
             convtype='convolution', channel_multiplier=1, use_bias=True,
             forget_bias=1.0, initializers=None, name="conv_lstm_cell"):
    """Construct ConvLSTMCellAP.

    Args:
      conv_ndims: Convolution dimensionality (1, 2 or 3).
      input_shape: Shape of the input as an int tuple or list, excluding batch
        size and time step. E.g. (height, width, num_channels) for images.
      num_channels_out: Number of output channels of the convLSTM.
      kernel_shape: Shape of the kernels as int tuple (of size 1, 2 or 3).
      convtype: convLSTM type
        - 'convolution': standard convLSTM layer
        - 'spatial': convolution is separated spatial (n,n) = (n,1) + (1,n)
        - 'depthwise': convolution is separated depthwise
        - 'separable': depthwise separable convolution (after depth-wise
          CONV, a 1x1 convolution is applied over all channels)
        The value is stored as given; it is not validated here.
      channel_multiplier: Channel multiplier for depthwise CONVs.
      use_bias: Whether to use bias in convolutions.
      forget_bias: Bias added to the forget gate pre-activation.
      initializers: Unused.
      name: Name of the module.

    Raises:
      ValueError: If `input_shape` is incompatible with `conv_ndims`.
    """
    super(ConvLSTMCellAP, self).__init__(name=name)

    if conv_ndims != len(input_shape) - 1:
        raise ValueError("Invalid input_shape {} for conv_ndims={}.".format(input_shape, conv_ndims))

    self._conv_ndims = conv_ndims
    self._input_shape = input_shape
    self._num_channels_out = num_channels_out
    self._kernel_shape = kernel_shape
    self._use_bias = use_bias
    self._convtype = convtype
    self._channel_multiplier = channel_multiplier
    self._forget_bias = forget_bias

    self._total_output_channels = num_channels_out
    # `input_shape` may be a tuple (as documented). `tuple + list` raises
    # TypeError, so take the spatial part as a list before appending channels.
    spatial_shape = list(self._input_shape[:-1])
    self._output_size = tensor_shape.TensorShape(
        spatial_shape + [self._total_output_channels])
    cell_state_size = tensor_shape.TensorShape(
        spatial_shape + [self._num_channels_out])
    self._state_size = rnn_cell_impl.LSTMStateTuple(cell_state_size,
                                                    self._output_size)
def __init__(self, num_units, use_peepholdes=False, initializer=None,
             num_proj=None, proj_clip=None, num_unit_shards=1,
             num_proj_shards=1, forget_bias=1.0, state_is_tuple=True,
             activation=math_ops.tanh, reuse=None):
    """Initialize the coupled input/forget gate LSTM cell.

    Args:
      num_units: int, number of units in the cell.
      use_peepholdes: bool, stored as `self._use_peepholes`. The parameter
        name is misspelled but kept so callers passing it by keyword still
        work.
      initializer: Initializer stored as `self._initializer`.
      num_proj: (optional) int. When set, the output size and the hidden part
        of the state are `num_proj` wide.
      proj_clip: Stored as `self._proj_clip`.
      num_unit_shards: Stored as `self._num_unit_shards`.
      num_proj_shards: Stored as `self._num_proj_shards`.
      forget_bias: Forget bias.
      state_is_tuple: If True, `state_size` is an `LSTMStateTuple`. Otherwise
        it is the summed width, and a warning is logged.
      activation: Activation function of the inner states.
      reuse: Forwarded to the base layer as `_reuse`.
    """
    super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
        # `warning` is the non-deprecated spelling (`warn` is a deprecated
        # alias in the stdlib logging module).
        logging.warning(
            "%s: Using a concatenated state is slower and will soon be "
            "deprecated. Use state_is_tuple=True.", self)
    self._num_units = num_units
    self._use_peepholes = use_peepholdes
    self._initializer = initializer
    self._num_proj = num_proj
    self._proj_clip = proj_clip
    self._num_unit_shards = num_unit_shards
    self._num_proj_shards = num_proj_shards
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation
    self._reuse = reuse

    if num_proj:
        self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
                            if state_is_tuple else num_units + num_proj)
        self._output_size = num_proj
    else:
        self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units)
                            if state_is_tuple else 2 * num_units)
        self._output_size = num_units
def call(self, inputs, state, scope=None):
    """Run one step of the convolutional LSTM.

    Args:
      inputs: Input tensor for this time step (channels-last).
      state: `LSTMStateTuple(cell, hidden)` from the previous step.
      scope: Unused; kept for signature compatibility.

    Returns:
      `(output, LSTMStateTuple(new_cell, output))`. With a skip connection,
      `output` also carries the input channels.
    """
    cell, hidden = state
    new_hidden = _conv([inputs, hidden], self._kernel_shape,
                       4 * self._output_channels, self._use_bias)
    # Split the four gate pre-activations along the channel (last) axis. The
    # former hard-coded axis=3 is only the channel axis for 2-D convolutions
    # (rank-4 tensors); axis=-1 is correct for 1-D, 2-D and 3-D alike. This
    # assumes `_conv` returns channels-last output, which the skip-connection
    # concat on axis=-1 below also relies on.
    gates = array_ops.split(value=new_hidden, num_or_size_splits=4, axis=-1)
    input_gate, new_input, forget_gate, output_gate = gates

    new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell
    new_cell += math_ops.sigmoid(input_gate) * math_ops.tanh(new_input)
    output = math_ops.tanh(new_cell) * math_ops.sigmoid(output_gate)

    if self._skip_connection:
        output = array_ops.concat([output, inputs], axis=-1)
    new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output)
    return output, new_state
def __init__(self, num_units, num_proj=None, use_biases=False,
             input_layer_norm=False, layer_norm=False, reuse=None):
    """Set up a NAS cell with optional layer normalization.

    Args:
      num_units: int, the number of units in the NAS cell.
      num_proj: (optional) int, output dimensionality of the projection
        matrices. If None, no projection is performed and the output width is
        `num_units`.
      use_biases: (optional) bool, if True then use biases within the cell.
        False by default.
      input_layer_norm: (optional) bool, stored as `self._input_layer_norm`.
        Presumably toggles normalization of the input contribution; confirm
        in `call`.
      layer_norm: (optional) bool, whether to use layer normalization.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already
        has the given variables, an error is raised.
    """
    super(LayerNormNASCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._num_proj = num_proj
    self._use_biases = use_biases
    self._input_layer_norm = input_layer_norm
    self._layer_norm = layer_norm
    self._reuse = reuse

    # The hidden part of the state and the output share the projected width.
    emitted_width = num_units if num_proj is None else num_proj
    self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, emitted_width)
    self._output_size = emitted_width
def call(self, inputs, state):
    """One step of an LSTM augmented with a per-example fast-weight matrix.

    Args:
      inputs: `2-D` tensor `[batch, input_size]` (concatenated with `h` on
        axis 1).
      state: `((c, h), fast_weights)`. `fast_weights` has batch as its
        leading dimension and is `[batch, num_units, num_units]` (the shape
        produced by the outer product below).

    Returns:
      `(new_h, FastWeightsStateTuple(LSTMStateTuple(new_c, new_h),
      fast_weights))`.
    """
    (c, h), fast_weights = state
    batch_size = array_ops.shape(fast_weights)[0]
    add = math_ops.add
    multiply = math_ops.multiply
    sigmoid = math_ops.sigmoid
    scalar_mul = math_ops.scalar_mul

    # Parameters of gates are concatenated into one multiply for efficiency.
    gate_inputs = math_ops.matmul(array_ops.concat([inputs, h], 1),
                                  self._kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
    if self._use_layer_norm:
        gate_inputs = layers.layer_norm(gate_inputs)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(value=gate_inputs, num_or_size_splits=4,
                                 axis=1)

    fast_j = self._activation(j)
    expand_fast_j = array_ops.expand_dims(fast_j, 1)  # [batch, 1, units]
    # Per-example outer product fast_j^T fast_j: [batch, units, units].
    outer = math_ops.matmul(array_ops.transpose(expand_fast_j, [0, 2, 1]),
                            expand_fast_j)
    # Fast-weight update A(t) = decay_rate * A(t-1) + learning_rate * outer.
    # The rates were previously swapped (learning rate scaling A(t-1) and
    # decay rate scaling the new outer product), contradicting their names.
    fast_weights = add(scalar_mul(self._fast_decay_rate, fast_weights),
                       scalar_mul(self._fast_learning_rate, outer))

    fast_weights_j = math_ops.matmul(
        gen_array_ops.reshape(fast_j, [batch_size, 1, -1]), fast_weights)
    fast_weights_j = gen_array_ops.reshape(fast_weights_j,
                                           [batch_size, self._num_units])
    fast_j = self._activation(add(fast_j, fast_weights_j))

    # Note that using `add` and `multiply` instead of `+` and `*` gives a
    # performance improvement. So using those at the cost of readability.
    new_c = add(multiply(c, sigmoid(add(f, self._forget_bias))),
                multiply(sigmoid(i), fast_j))
    if self._use_layer_norm:
        new_c = layers.layer_norm(new_c)
    new_h = multiply(self._activation(new_c), sigmoid(o))

    return new_h, FastWeightsStateTuple(
        rnn_cell_impl.LSTMStateTuple(new_c, new_h), fast_weights)
def call(self, inputs, state, time):
    """LSTM cell with layer normalization and recurrent dropout.

    Additionally, the state holds several replicates along axis 1. Each call
    emits the hidden slice of the replicate selected by `time`. At the last
    step of each group of `self._group_size` steps, that replicate's cell
    and hidden state are zeroed.

    Args:
      inputs: Input tensor (concatenated with `h` on the last axis).
      state: `(c, h)` pair. Multiplication by the `[1, num_replicates, 1]`
        reset mask below implies axis 1 indexes replicates; presumably
        `[batch, num_replicates, units]`. TODO confirm.
      time: Step counter used with `tf.mod`/`tf.floor_div`. Presumably a
        scalar integer tensor; verify against the caller.

    Returns:
      `((new_h_current, new_h), LSTMStateTuple(new_c_reset, new_h_reset))`,
      where `new_h_current` is `new_h` gathered at the active replicate.
    """
    # Position within the current group, and which replicate is active.
    state_index_in_group = tf.mod(time, self._group_size)
    group_index = tf.floor_div(time, self._group_size)
    replicate_index = tf.mod(group_index, self._num_replicates)
    c, h = state
    args = array_ops.concat([inputs, h], -1)
    concat = self._linear(args)
    dtype = args.dtype
    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=-1)
    if self._layer_norm:
        i = self._norm(i, "input", dtype=dtype)
        j = self._norm(j, "transform", dtype=dtype)
        f = self._norm(f, "forget", dtype=dtype)
        o = self._norm(o, "output", dtype=dtype)
    g = self._activation(j)
    # Dropout is skipped only for a Python float keep_prob >= 1.
    if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
        g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)
    #(i,g,f,o) = (tf.expand_dims(val, -1) for val in (i,g,f,o))
    new_c = (c * math_ops.sigmoid(f + self._forget_bias) +
             math_ops.sigmoid(i) * g)
    if self._layer_norm:
        new_c = self._norm(new_c, "state", dtype=dtype)
    new_h = self._activation(new_c) * math_ops.sigmoid(o)
    # Hidden output of the currently active replicate only.
    new_h_current = tf.gather(new_h, replicate_index, axis=1)
    #here we reset the correct state (but only if we reached the end of the group)
    # scatter_nd builds a one-hot vector of length num_replicates at
    # replicate_index; tmp is its complement (0 at the active replicate).
    # NOTE(review): tf.constant([1.0]) is float32; if `dtype` is not float32
    # the later multiply with new_c may fail. Confirm only float32 is used.
    tmp = 1 - tf.scatter_nd(
        tf.expand_dims(tf.expand_dims(replicate_index, 0), 0),
        tf.constant([1.0]), tf.constant([self._num_replicates]))
    # Shape [1, num_replicates, 1] so it broadcasts over batch and units.
    reset_mask = tf.expand_dims(tf.expand_dims(tmp, 0), -1)
    # True on the last step of the current group.
    reset_flag = tf.equal(state_index_in_group + 1, self._group_size)
    new_c_reset = tf.cond(reset_flag, lambda: new_c * reset_mask,
                          lambda: new_c)
    new_h_reset = tf.cond(reset_flag, lambda: new_h * reset_mask,
                          lambda: new_h)
    new_state = rnn_cell_impl.LSTMStateTuple(new_c_reset, new_h_reset)
    # The un-reset new_h is emitted alongside the active replicate's slice.
    return (new_h_current, new_h), new_state
def call(self, inputs, state):
    """Run one LSTM step through `self._CudnnLSTM`, with optional clipping and projection.

    Args:
      inputs: Input tensor for one step; a leading length-1 axis is added
        before the cuDNN call.
      state: `(c_prev, m_prev)` if `self._state_is_tuple`, otherwise one
        tensor with `c` in the first `num_units` columns and `m` in the next
        `num_proj` columns.

    Returns:
      `(m, new_state)`. `new_state` is `LSTMStateTuple(c, m)` or, when not
      tuple-based, `concat([c, m], 1)`.
    """
    num_proj = self._num_units if self._num_proj is None else self._num_proj

    if self._state_is_tuple:
        (c_prev, m_prev) = state
    else:
        c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
        m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    scope = vs.get_variable_scope()
    with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:
        # NOTE(review): a partitioner only affects variables created inside
        # this scope; `self._params` is passed in, so this may have no effect
        # unless `_CudnnLSTM` creates variables itself. Confirm.
        if self._num_unit_shards is not None:
            unit_scope.set_partitioner(
                partitioned_variables.fixed_size_partitioner(
                    self._num_unit_shards))
        # Add a length-1 leading axis to inputs and states (presumably the
        # time axis expected by the cuDNN op; verify), then squeeze it back.
        _, output_h, output_c = self._CudnnLSTM(
            input_data=array_ops.expand_dims(inputs, [0]),
            input_h=array_ops.expand_dims(m_prev, [0]),
            input_c=array_ops.expand_dims(c_prev, [0]),
            params=self._params)
        c = array_ops.squeeze(output_c, [0])
        m = array_ops.squeeze(output_h, [0])

        if self._cell_clip is not None:
            # pylint: disable=invalid-unary-operand-type
            c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
            # pylint: enable=invalid-unary-operand-type

        if self._num_proj is not None:
            with vs.variable_scope("projection") as proj_scope:
                if self._num_proj_shards is not None:
                    proj_scope.set_partitioner(
                        partitioned_variables.fixed_size_partitioner(
                            self._num_proj_shards))
                # NOTE(review): `rnn_cell_impl._linear` is a private TF
                # helper and may be absent in some TF versions.
                m = rnn_cell_impl._linear(m, self._num_proj, bias=False)

            if self._proj_clip is not None:
                # pylint: disable=invalid-unary-operand-type
                m = clip_ops.clip_by_value(m, -self._proj_clip,
                                           self._proj_clip)
                # pylint: enable=invalid-unary-operand-type

    new_state = (rnn_cell_impl.LSTMStateTuple(c, m) if self._state_is_tuple
                 else array_ops.concat([c, m], 1))
    return m, new_state
def call(self, inputs, state, scope=None):
    """One step of a convolutional JANET or LSTM cell, selected by `self._kind`.

    JANET: two gate maps (candidate, forget);
      `new_cell = s(f) * cell + (1 - s(f)) * tanh(candidate / 3)` and the
      output equals `new_cell`.
    LSTM: four gate maps (input, candidate, forget, output), with tanh inputs
      scaled by 1/3 and no forget bias.

    Args:
      inputs: Input tensor for this time step.
      state: `LSTMStateTuple(cell, hidden)`.
      scope: Unused; kept for signature compatibility.

    Returns:
      `(output, LSTMStateTuple(new_cell, output))`. With a skip connection,
      `output` also carries the input channels.

    Raises:
      ValueError: If `self._kind` is neither "JANET" nor "LSTM". (Previously
        this surfaced as an UnboundLocalError.)
    """
    cell, hidden = state
    # Channels follow batch + conv_ndims spatial axes.
    channel_axis = self._conv_ndims + 1

    if self._kind == "JANET":
        new_hidden = _conv([inputs, hidden], self._kernel_shape,
                           2 * self._output_channels, self._use_bias,
                           self._t_max, self._kind)
        new_input, forget_gate = array_ops.split(
            value=new_hidden, num_or_size_splits=2, axis=channel_axis)
        forget = math_ops.sigmoid(forget_gate)
        new_cell = forget * cell + (1 - forget) * math_ops.tanh(new_input / 3)
        output = new_cell
    elif self._kind == "LSTM":
        new_hidden = _conv([inputs, hidden], self._kernel_shape,
                           4 * self._output_channels, self._use_bias,
                           self._t_max, self._kind)
        input_gate, new_input, forget_gate, output_gate = array_ops.split(
            value=new_hidden, num_or_size_splits=4, axis=channel_axis)
        # Debug prints of the gate shapes were removed: they wrote to stdout
        # every time the graph was built.
        new_cell = math_ops.sigmoid(forget_gate) * cell
        new_cell += math_ops.sigmoid(input_gate) * math_ops.tanh(
            new_input / 3)
        output = math_ops.tanh(new_cell / 3) * math_ops.sigmoid(output_gate)
    else:
        raise ValueError(
            "Unsupported cell kind {!r}; expected 'JANET' or 'LSTM'.".format(
                self._kind))

    if self._skip_connection:
        output = array_ops.concat([output, inputs], axis=-1)
    new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output)
    return output, new_state
def call(self, inputs, state):
    """Advance the LSTM by one step using the libxsmm fused kernel.

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: Two-element state `(cs_prev, h_prev)`; any other length is
        rejected.

    Returns:
      `(h, LSTMStateTuple(cs, h))`.

    Raises:
      ValueError: If `state` does not have exactly two elements.
    """
    if len(state) != 2:
        raise ValueError("Expecting state to be a tuple with length 2.")

    # Peephole connections are switched off: the kernel is called with
    # use_peephole=False, and all three peephole vectors are the same
    # zero tensor.
    wci = wcf = wco = array_ops.zeros([self._num_units])

    cs_prev, h_prev = state
    _, cs, _, _, _, _, h = xsmm_lstm.xsmm_lstm_cell(
        x=inputs,
        cs_prev=cs_prev,
        h_prev=h_prev,
        w=self._kernel,
        w_t=self._kernel_trans,
        wci=wci,
        wcf=wcf,
        wco=wco,
        b=self._bias,
        forget_bias=self._forget_bias,
        cell_clip=-1,
        use_peephole=False,
        w_in_kcck=self._w_in_kcck,
        name=self._name)
    return h, rnn_cell_impl.LSTMStateTuple(cs, h)
def call(self, inputs, state, scope=None):
    """Run one step of the convolutional LSTM (standard or factored convs).

    Args:
      inputs: Input tensor for this time step.
      state: `LSTMStateTuple(cell, hidden)`, i.e. c_{t-1} and h_{t-1}.
      scope: Unused; kept for signature compatibility.

    Returns:
      `(output, LSTMStateTuple(new_cell, output))`.
    """
    # Split the state tuple into the last cell state c_{t-1} and the last
    # hidden state h_{t-1}.
    cell, hidden = state

    # NOTE(review): no channel multiplier is passed here; verify whether the
    # 'depthwise'/'separable' convtypes of `_convs_unoptimized` need one.
    new_hidden = _convs_unoptimized([inputs, hidden], self._kernel_shape,
                                    4 * self._num_channels_out,
                                    self._use_bias, convtype=self._convtype)

    # The channels of new_hidden are the concatenated pre-activations of the
    # gates and of the candidate for the next cell state; split them apart.
    gates = array_ops.split(value=new_hidden, num_or_size_splits=4,
                            axis=self._conv_ndims + 1)
    input_gate, new_input, forget_gate, output_gate = gates

    new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell
    new_cell += math_ops.sigmoid(input_gate) * math_ops.tanh(new_input)
    output = math_ops.sigmoid(output_gate) * math_ops.tanh(new_cell)

    new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output)
    return output, new_state
def testLSTMCellLayerNorm(self):
    """Smoke test: LayerNormLSTMCell with projection yields expected shapes.

    Every batch row is fed identical inputs and states, so every row of the
    output and of both state tensors must also be identical.
    """
    with self.test_session() as sess:
        num_units = 2
        num_proj = 3
        batch_size = 1
        input_size = 4
        with variable_scope.variable_scope(
                "root", initializer=init_ops.constant_initializer(0.5)):
            x = array_ops.zeros([batch_size, input_size])
            c = array_ops.zeros([batch_size, num_units])
            h = array_ops.zeros([batch_size, num_proj])
            state = rnn_cell_impl.LSTMStateTuple(c, h)
            cell = contrib_rnn_cell.LayerNormLSTMCell(num_units=num_units,
                                                      num_proj=num_proj,
                                                      forget_bias=1.0,
                                                      layer_norm=True,
                                                      norm_gain=1.0,
                                                      norm_shift=0.0)
            g, out_m = cell(x, state)
            sess.run([variables_lib.global_variables_initializer()])
            res = sess.run(
                [g, out_m], {
                    x.name: np.ones((batch_size, input_size)),
                    c.name: 0.1 * np.ones((batch_size, num_units)),
                    h.name: 0.1 * np.ones((batch_size, num_proj))
                })
            self.assertEqual(len(res), 2)
            # The numbers in results were not calculated, this is mostly just
            # a smoke test.
            self.assertEqual(res[0].shape, (batch_size, num_proj))
            self.assertEqual(res[1][0].shape, (batch_size, num_units))
            self.assertEqual(res[1][1].shape, (batch_size, num_proj))
            # Same inputs for every row, so same outputs and states per row.
            # res[1] is the (c, h) state pair, so compare each component; the
            # previous `res[1][0, :]` indexed the tuple itself and would fail.
            # With batch_size == 1 this loop does not execute.
            for i in range(1, batch_size):
                self.assertTrue(
                    float(np.linalg.norm(res[0][0, :] - res[0][i, :])) < 1e-6)
                self.assertTrue(
                    float(np.linalg.norm(res[1][0][0, :] -
                                         res[1][0][i, :])) < 1e-6)
                self.assertTrue(
                    float(np.linalg.norm(res[1][1][0, :] -
                                         res[1][1][i, :])) < 1e-6)
def __call__(self, x, states_prev, scope=None):
    """Long short-term memory cell (LSTM) step built on `_lstm_block_cell`.

    Args:
      x: `2-D` input tensor whose second dimension must be statically known.
      states_prev: Two-element state `(cs_prev, h_prev)`.
      scope: Optional variable scope; defaults to `self._names["scope"]`.

    Returns:
      `(h, LSTMStateTuple(cs, h))`.

    Raises:
      ValueError: If `x`'s second dimension is unknown (or 0), or if
        `states_prev` does not have exactly two elements.
    """
    with vs.variable_scope(scope or self._names["scope"]):
        x_shape = x.get_shape().with_rank(2)
        input_size = x_shape[1].value
        if not input_size:
            raise ValueError("Expecting x_shape[1] to be set: %s" % str(x_shape))
        if len(states_prev) != 2:
            raise ValueError(
                "Expecting states_prev to be a tuple with length 2.")

        units = self._num_units
        names = self._names
        # Fused kernel: [input + hidden, 4 * units]; one bias per column.
        w = vs.get_variable(names["W"], [input_size + units, units * 4])
        b = vs.get_variable(names["b"], [w.get_shape().with_rank(2)[1].value],
                            initializer=init_ops.constant_initializer(0.0))

        if not self._use_peephole:
            wci = wco = wcf = array_ops.zeros([units])
        else:
            wci = vs.get_variable(names["wci"], [units])
            wco = vs.get_variable(names["wco"], [units])
            wcf = vs.get_variable(names["wcf"], [units])

        cs_prev, h_prev = states_prev
        # cell_clip=-1 is passed when clipping is disabled.
        clip_arg = None if self._clip_cell else -1
        _, cs, _, _, _, _, h = _lstm_block_cell(
            x, cs_prev, h_prev, w, b,
            wci=wci, wco=wco, wcf=wcf,
            forget_bias=self._forget_bias,
            cell_clip=clip_arg,
            use_peephole=self._use_peephole)
        return h, rnn_cell_impl.LSTMStateTuple(cs, h)
def call(self, inputs, state):
    """LSTM cell with layer normalization and recurrent dropout.

    A single normalization ('all_norm') is applied to the full four-gate
    pre-activation before it is split. The new cell state is normalized
    again ("state") when `self._layer_norm` is set. Dropout on the candidate
    is skipped only when `self._keep_prob` is a Python float >= 1.

    Returns:
      `(new_h, LSTMStateTuple(new_c, new_h))`.
    """
    c_prev, h_prev = state
    joined = array_ops.concat([inputs, h_prev], 1)
    normed = self._norm(self._linear(joined), 'all_norm')
    in_gate, cand, forget_gate, out_gate = array_ops.split(
        value=normed, num_or_size_splits=4, axis=1)

    candidate = self._activation(cand)
    apply_dropout = (not isinstance(self._keep_prob, float)
                     or self._keep_prob < 1)
    if apply_dropout:
        candidate = nn_ops.dropout(candidate, self._keep_prob,
                                   seed=self._seed)

    new_c = (c_prev * math_ops.sigmoid(forget_gate + self._forget_bias) +
             math_ops.sigmoid(in_gate) * candidate)
    if self._layer_norm:
        new_c = self._norm(new_c, "state")
    new_h = self._activation(new_c) * math_ops.sigmoid(out_gate)
    return new_h, rnn_cell_impl.LSTMStateTuple(new_c, new_h)
def call(self, inputs, state, scope=None):
    """One ConvLSTM step that can carry extra "dynamic flow" channels.

    When `self.dy_adj > 0` and `self._input_dim` is set, channels of `inputs`
    beyond the first `self._input_dim` are split off as `dy_f` and passed to
    `self._conv`. Only the first `self._input_dim` channels are treated as
    the regular input. When `self.output_dy_adj > 0`, `dy_f` is also appended
    to the returned output. It is not appended to the hidden state.

    Args:
      inputs: Input tensor; the split uses its last axis.
      state: `LSTMStateTuple(cell, hidden)`.
      scope: Unused; kept for signature compatibility.

    Returns:
      `(output, LSTMStateTuple(new_cell, lstm_output))`.

    Raises:
      ValueError: If `self.output_dy_adj > 0` but no dynamic-flow channels
        were split off from `inputs`. (Previously this crashed with an
        AttributeError on `None`.)
    """
    dy_f = None
    if self.dy_adj > 0 and self._input_dim is not None:
        whole_input_dim = inputs.get_shape().as_list()
        dy_f_dim = whole_input_dim[-1] - self._input_dim
        if dy_f_dim > 0:
            inputs, dy_f = tf.split(
                inputs, num_or_size_splits=[self._input_dim, dy_f_dim],
                axis=-1)

    cell, hidden = state
    new_hidden = self._conv(args=[inputs, hidden],
                            filter_size=self._kernel_shape,
                            num_features=4 * self._output_channels,
                            bias=self._use_bias,
                            bias_start=0,
                            dy_f=dy_f)
    gates = array_ops.split(value=new_hidden, num_or_size_splits=4, axis=3)
    input_gate, new_input, forget_gate, output_gate = gates

    new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell
    new_cell += math_ops.sigmoid(input_gate) * math_ops.tanh(new_input)
    output = math_ops.tanh(new_cell) * math_ops.sigmoid(output_gate)

    if self._skip_connection:
        output = array_ops.concat([output, inputs], axis=-1)
    new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output)

    if self.output_dy_adj > 0:
        # Debug prints of the shapes were removed (they ran at graph build).
        if dy_f is None:
            raise ValueError(
                "output_dy_adj > 0 requires dynamic-flow channels in `inputs` "
                "(dy_adj > 0, input_dim set, and more than input_dim "
                "channels).")
        output = tf.concat([output, dy_f], axis=-1)
    return output, new_state
def call(self, inputs, state, scope=None):
    """Advance the convolutional LSTM by one step.

    Gate pre-activations come from a single `_conv` over `[inputs, hidden]`
    (dilations=1, variable name "kernel") and are split along the channel
    axis, which follows the `conv_ndims` spatial axes.

    Args:
      inputs: Input tensor for this time step.
      state: `LSTMStateTuple(cell, hidden)`.
      scope: Unused; kept for signature compatibility.

    Returns:
      `(output, LSTMStateTuple(new_cell, output))`. With a skip connection,
      `output` also carries the input channels.
    """
    prev_cell, prev_hidden = state
    conv_out = _conv([inputs, prev_hidden], self._kernel_shape,
                     4 * self._output_channels, self._use_bias,
                     dilations=1, name="kernel")
    channel_axis = self._conv_ndims + 1
    in_gate, candidate, forget_gate, out_gate = array_ops.split(
        value=conv_out, num_or_size_splits=4, axis=channel_axis)

    kept = math_ops.sigmoid(forget_gate + self._forget_bias) * prev_cell
    written = math_ops.sigmoid(in_gate) * math_ops.tanh(candidate)
    new_cell = kept + written
    output = math_ops.tanh(new_cell) * math_ops.sigmoid(out_gate)

    if self._skip_connection:
        output = array_ops.concat([output, inputs], axis=-1)
    return output, rnn_cell_impl.LSTMStateTuple(new_cell, output)
def call(self, inputs, state, scope=None):
    """One step of a convolutional LSTM with elementwise peephole connections.

    Peepholes: the input and forget gates see the previous cell (`w_ci`,
    `w_cf`), and the output gate sees the new cell (`w_co`).

    Args:
      inputs: Input tensor for this time step.
      state: `LSTMStateTuple(cell, hidden)`, i.e. C_{t-1} and H_{t-1}.
      scope: Unused; kept for signature compatibility.

    Returns:
      `(output, LSTMStateTuple(new_cell, output))`.
    """
    cell, hidden = state  # C_{t-1}, H_{t-1}
    new_hidden = _conv([inputs, hidden], self._kernel_shape,
                       4 * self._output_channels, self._use_bias)
    gates = array_ops.split(value=new_hidden, num_or_size_splits=4,
                            axis=self._conv_ndims + 1)
    input_gate, new_input, forget_gate, output_gate = gates

    # Peephole weights are shared across the batch, so their shape is the
    # per-example cell shape (everything after the batch axis). Using the
    # full `cell.shape` included the batch dimension: variable creation
    # fails when the batch size is unknown, and otherwise each batch slot
    # gets its own weights. NOTE: checkpoints saved with the old batch-shaped
    # variables will not load into these.
    peephole_shape = cell.shape[1:]
    w_ci = vs.get_variable("w_ci", peephole_shape, inputs.dtype)
    w_cf = vs.get_variable("w_cf", peephole_shape, inputs.dtype)
    w_co = vs.get_variable("w_co", peephole_shape, inputs.dtype)

    new_cell = math_ops.sigmoid(forget_gate + self._forget_bias +
                                w_cf * cell) * cell
    new_cell += math_ops.sigmoid(input_gate + w_ci * cell) * math_ops.tanh(
        new_input)
    output = math_ops.tanh(new_cell) * math_ops.sigmoid(output_gate +
                                                        w_co * new_cell)

    if self._skip_connection:
        output = array_ops.concat([output, inputs], axis=-1)
    new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output)
    return output, new_state
def call(self, inputs, state):
    """2D Convolutional LSTM cell with (optional) normalization and recurrent dropout.

    `inputs` may be a plain spatial tensor or a pair `(spatial,
    non_spatial)`. For a pair, a dense projection of the non-spatial part is
    broadcast over both spatial axes and added to the conv pre-activation.

    Normalization (when `self._normalizer_fn` is set) is applied either once
    to the concatenated gate tensor or separately per gate, according to
    `self._separate_norms`. It is also applied to the new cell state.

    Returns:
      `(new_h, LSTMStateTuple(new_c, new_h))`. With a skip connection, `new_h`
      (and therefore the hidden state) also carries the spatial input.
    """
    c_prev, h_prev = state

    has_side_input = isinstance(inputs, (list, tuple))
    if has_side_input:
        inputs, side_input = inputs

    pre_act = self._conv2d(array_ops.concat([inputs, h_prev], -1))
    if has_side_input:
        pre_act = pre_act + self._dense(side_input)[:, None, None, :]

    normalize = self._normalizer_fn
    if normalize and not self._separate_norms:
        pre_act = self._norm(pre_act, "input_transform_forget_output")

    gate_i, gate_j, gate_f, gate_o = array_ops.split(
        value=pre_act, num_or_size_splits=4, axis=-1)
    if normalize and self._separate_norms:
        gate_i = self._norm(gate_i, "input")
        gate_j = self._norm(gate_j, "transform")
        gate_f = self._norm(gate_f, "forget")
        gate_o = self._norm(gate_o, "output")

    candidate = self._activation_fn(gate_j)
    if not isinstance(self._keep_prob, float) or self._keep_prob < 1:
        candidate = nn_ops.dropout(candidate, self._keep_prob,
                                   seed=self._seed)

    new_c = (c_prev * math_ops.sigmoid(gate_f + self._forget_bias) +
             math_ops.sigmoid(gate_i) * candidate)
    if normalize:
        new_c = self._norm(new_c, "state")
    new_h = self._activation_fn(new_c) * math_ops.sigmoid(gate_o)

    if self._skip_connection:
        new_h = array_ops.concat([new_h, inputs], axis=-1)
    return new_h, rnn_cell_impl.LSTMStateTuple(new_c, new_h)
def call(self, inputs, state):
    """One step of a basic LSTM (no forget bias, no peepholes).

    Args:
      inputs: `2-D` tensor, concatenated with `h` on axis 1.
      state: `(c, h)` pair from the previous step.

    Returns:
      `(new_h, LSTMStateTuple(new_c, new_h))`.

    NOTE(review): an unfinished attention computation was removed from this
    method. It called the non-existent `Tensor.expand_dim`, so every call
    raised AttributeError. Its result (`alpha`) was never used, and its shape
    assumptions (3-D `inputs`) contradicted the 2-D concat with `h` above it.
    A leftover debug `print(new_h)` was removed too. Reintroduce attention
    deliberately if it is needed.
    """
    c, h = state
    gate_inputs = tf.add(
        tf.matmul(tf.concat([inputs, h], 1), self._kernel), self._bias)
    i, j, f, o = tf.split(value=gate_inputs, num_or_size_splits=4, axis=1)
    # The input gate must be squashed by a sigmoid like the other gates;
    # the raw pre-activation `i` was previously used unbounded.
    new_c = tf.add(tf.multiply(c, tf.nn.sigmoid(f)),
                   tf.multiply(tf.nn.sigmoid(i), tf.nn.tanh(j)))
    new_h = tf.multiply(tf.nn.sigmoid(o), tf.nn.tanh(new_c))
    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
    return new_h, new_state
def call(self, inputs, state):
    """Run one step of a coupled input/forget-gate LSTM with peepholes.

    The input gate `i_t` also acts as the complementary forget gate:
    `c_t = (1 - i_t) * c_prev + i_t * activation(candidate)`. The peephole
    terms use full `[num_units, num_units]` matrices (`w_ci`, `w_co`).
    `w_ci` reads the previous cell; `w_co` reads the new cell.

    Args:
      inputs: input Tensor, 2D, `[batch, input_size]`. `input_size` must be
        statically known.
      state: if `state_is_tuple` is False, this must be a state Tensor,
        `2-D, [batch, state_size]`, holding `c` then `m`. If `state_is_tuple`
        is True, this must be a tuple `(c_state, m_state)` of `2-D` Tensors.

    Returns:
      A tuple containing:
      - A `2-D, [batch, num_units]` Tensor `h_t`, the output of the LSTM
        after reading `inputs` when the previous state was `state`. No
        projection is applied here, even if `num_proj` is set.
      - The new state: `LSTMStateTuple(c_t, h_t)` if `state_is_tuple`,
        otherwise `concat([c_t, h_t], 1)`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via static
        shape inference.
    """
    sigmoid = math_ops.sigmoid

    # NOTE(review): with a concatenated state and `num_proj` != num_units,
    # m_prev is `num_proj` wide but is multiplied by [num_units, num_units]
    # matrices below, so the matmul will fail. Confirm num_proj is unused here.
    num_proj = self._num_units if self._num_proj is None else self._num_proj
    if self._state_is_tuple:
        (c_prev, m_prev) = state
    else:
        c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
        m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
        raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    # Input gate weights
    self.w_xi = tf.get_variable("_w_xi", [input_size.value, self._num_units])
    self.w_hi = tf.get_variable("_w_hi", [self._num_units, self._num_units])
    self.w_ci = tf.get_variable("_w_ci", [self._num_units, self._num_units])
    # Output gate weights
    self.w_xo = tf.get_variable("_w_xo", [input_size.value, self._num_units])
    self.w_ho = tf.get_variable("_w_ho", [self._num_units, self._num_units])
    self.w_co = tf.get_variable("_w_co", [self._num_units, self._num_units])
    # Cell weights
    self.w_xc = tf.get_variable("_w_xc", [input_size.value, self._num_units])
    self.w_hc = tf.get_variable("_w_hc", [self._num_units, self._num_units])
    # Initialize the bias vectors
    self.b_i = tf.get_variable("_b_i", [self._num_units],
                               initializer=init_ops.zeros_initializer())
    self.b_c = tf.get_variable("_b_c", [self._num_units],
                               initializer=init_ops.zeros_initializer())
    self.b_o = tf.get_variable("_b_o", [self._num_units],
                               initializer=init_ops.zeros_initializer())

    i_t = sigmoid(math_ops.matmul(inputs, self.w_xi) +
                  math_ops.matmul(m_prev, self.w_hi) +
                  math_ops.matmul(c_prev, self.w_ci) + self.b_i)
    c_t = ((1 - i_t) * c_prev +
           i_t * self._activation(math_ops.matmul(inputs, self.w_xc) +
                                  math_ops.matmul(m_prev, self.w_hc) +
                                  self.b_c))
    o_t = sigmoid(math_ops.matmul(inputs, self.w_xo) +
                  math_ops.matmul(m_prev, self.w_ho) +
                  math_ops.matmul(c_t, self.w_co) + self.b_o)
    h_t = o_t * self._activation(c_t)

    new_state = (rnn_cell_impl.LSTMStateTuple(c_t, h_t) if self._state_is_tuple
                 else array_ops.concat([c_t, h_t], 1))
    return h_t, new_state
def __init__(self, input_shape, filters, kernel_size, forget_bias=1.0,
             activation_fn=math_ops.tanh, normalizer_fn=None,
             separate_norms=True, norm_gain=1.0, norm_shift=0.0,
             dropout_keep_prob=1.0, dropout_prob_seed=None,
             skip_connection=False, reuse=None):
    """Initializes the basic convolutional LSTM cell.

    Args:
      input_shape: int tuple, Shape of the input, excluding the batch size.
        A list is accepted as well.
      filters: int, The number of filters of the conv LSTM cell.
      kernel_size: int tuple, The kernel size of the conv LSTM cell. A single
        int is expanded to a square ``[k, k]`` kernel.
      forget_bias: float, The bias added to forget gates.
      activation_fn: Activation function of the inner states.
      normalizer_fn: If specified, this normalization will be applied before the
        internal nonlinearities.
      separate_norms: If set to `False`, the normalizer_fn is applied to the
        concatenated tensor that follows the convolution, i.e. before splitting
        the tensor. This case is slightly faster but it might be functionally
        different, depending on the normalizer_fn (it's functionally the same
        for instance norm but not for layer norm). Default: `True`.
      norm_gain: float, The layer normalization gain initial value. If
        `normalizer_fn` is `None`, this argument will be ignored.
      norm_shift: float, The layer normalization shift initial value. If
        `normalizer_fn` is `None`, this argument will be ignored.
      dropout_keep_prob: unit Tensor or float between 0 and 1 representing the
        recurrent dropout probability value. If float and 1.0, no dropout will
        be applied.
      dropout_prob_seed: (optional) integer, the randomness seed.
      skip_connection: If set to `True`, concatenate the input to the output
        of the conv LSTM. Default: `False`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(BasicConv2DLSTMCell, self).__init__(_reuse=reuse)
    self._input_shape = input_shape
    self._filters = filters
    # Accept either a single int (square kernel) or a per-dimension sequence.
    if isinstance(kernel_size, (tuple, list)):
        self._kernel_size = list(kernel_size)
    else:
        self._kernel_size = [kernel_size] * 2
    self._forget_bias = forget_bias
    self._activation_fn = activation_fn
    self._normalizer_fn = normalizer_fn
    self._separate_norms = separate_norms
    self._g = norm_gain
    self._b = norm_shift
    self._keep_prob = dropout_keep_prob
    self._seed = dropout_prob_seed
    self._skip_connection = skip_connection
    self._reuse = reuse

    # With a skip connection the input channels are concatenated onto the
    # cell output, widening it.
    if self._skip_connection:
        output_channels = self._filters + self._input_shape[-1]
    else:
        output_channels = self._filters

    # Convert the spatial dims to a list: `input_shape` is documented as a
    # tuple, and `tuple + list` would raise TypeError.
    spatial_dims = list(self._input_shape[:-1])
    cell_size = tensor_shape.TensorShape(spatial_dims + [self._filters])
    self._output_size = tensor_shape.TensorShape(spatial_dims + [output_channels])
    # State is (cell state, hidden output); the hidden part matches the output.
    self._state_size = rnn_cell_impl.LSTMStateTuple(cell_size, self._output_size)
def call(self, inputs, initial_state=None, dtype=None, sequence_length=None):
    """Unroll this LSTM over a full sequence, starting from `initial_state`.

    Args:
      inputs: `3-D` tensor shaped `[time_len, batch_size, input_size]`, or a
        list of `time_len` tensors shaped `[batch_size, input_size]`.
      initial_state: optional `(initial_cell_state, initial_output)` pair of
        tensors shaped `[batch_size, self.num_units]`. When omitted, zeros of
        type `dtype` are used.
      dtype: data type of the initial state and outputs. Required when
        `initial_state` is not given. Otherwise it defaults to the dtype of
        `initial_state[0]`.
      sequence_length: optional `int32`/`int64` vector of size `[batch_size]`
        with the valid length of each sequence. Without it, every sequence is
        treated as running for all `time_len` steps.

    Returns:
      `(outputs, final_state)`:
      - `outputs`: `[time_len, batch_size, output_size]` tensor, or a list of
        per-step tensors when `inputs` was given as a list. Steps beyond
        `sequence_length` are zeroed.
      - `final_state`: `LSTMStateTuple(cell_state, output)` taken at each
        sequence's last valid step. This is the initial state for zero-length
        sequences.

    Raises:
      ValueError: on shape mismatches, a malformed `initial_state`, or when
        neither `initial_state` nor `dtype` is provided.
    """
    inputs_were_list = isinstance(inputs, list)
    if inputs_were_list:
        inputs = array_ops.stack(inputs)

    static_shape = inputs.get_shape().with_rank(3)
    if not static_shape[2]:
        raise ValueError("Expecting inputs_shape[2] to be set: %s" % static_shape)

    # Use static sizes when known, dynamic ones otherwise.
    batch_size = static_shape[1].value
    batch_size = array_ops.shape(inputs)[1] if batch_size is None else batch_size
    time_len = static_shape[0].value
    time_len = array_ops.shape(inputs)[0] if time_len is None else time_len

    # Resolve initial_state / dtype defaults.
    if initial_state is not None:
        if len(initial_state) != 2:
            raise ValueError(
                "Expecting initial_state to be a tuple with length 2 or None")
        dtype = initial_state[0].dtype if dtype is None else dtype
    elif dtype is None:
        raise ValueError("Either initial_state or dtype needs to be specified")
    else:
        zeros = array_ops.zeros(array_ops.stack([batch_size, self.num_units]),
                                dtype=dtype)
        initial_state = zeros, zeros

    if sequence_length is not None:
        sequence_length = ops.convert_to_tensor(sequence_length)

    init_cell, init_out = initial_state  # pylint: disable=unpacking-non-sequence
    cell_states, outputs = self._call_cell(inputs, init_cell, init_out, dtype,
                                           sequence_length)

    if sequence_length is None:
        # Every sequence runs to the end: the final state is the last step.
        final_cell_state = cell_states[-1]
        final_output = outputs[-1]
    else:
        # Zero the outputs past each sequence's length.
        step_mask = array_ops.transpose(
            array_ops.sequence_mask(sequence_length, time_len, dtype=dtype), [1, 0])
        step_mask = array_ops.tile(array_ops.expand_dims(step_mask, [-1]),
                                   [1, 1, self.num_units])
        outputs *= step_mask
        # Put the initial state at index 0 so that gathering at
        # `sequence_length` picks the last valid step (index 0, i.e. the
        # initial state, for empty sequences).
        padded_cell_states = array_ops.concat(
            [array_ops.expand_dims(init_cell, [0]), cell_states], 0)
        padded_outputs = array_ops.concat(
            [array_ops.expand_dims(init_out, [0]), outputs], 0)
        final_cell_state = self._gather_states(padded_cell_states,
                                               sequence_length, batch_size)
        final_output = self._gather_states(padded_outputs, sequence_length,
                                           batch_size)

    if inputs_were_list:
        # Mirror the caller's input container type.
        outputs = array_ops.unstack(outputs)

    final_state = rnn_cell_impl.LSTMStateTuple(final_cell_state, final_output)
    return outputs, final_state
def state_size(self):
    """Return the `(cell state, output)` sizes, both equal to `self._num_units`."""
    units = self._num_units
    return rnn_cell_impl.LSTMStateTuple(units, units)
def __init__(self, sess, config, api, log_dir, forward, scope=None):
    """Build the paragraph model graph (encoder, decoder and, if training, losses).

    Args:
      sess: session, forwarded to `self.optimize`.
      config: hyper-parameter object. Fields read here: max_utt_len,
        cxt_cell_size, sent_cell_size, dec_cell_size, num_topics, init_lr,
        lr_decay, embed_size, keep_prob, cell_type, num_layer, dec_keep_prob.
      api: object providing `vocab` (its length is the vocabulary size) and
        `rev_vocab` (a token -> id mapping that must contain "<s>" and "</s>").
      log_dir: forwarded to `self.optimize`.
      forward: if True, build the inference decoder only. If False, build
        the decoder fed with `output_tokens`, plus losses, summaries and the
        optimizer.
      scope: stored as `self.scope`. Presumably used by `optimize` to select
        variables; confirm.
    """
    self.vocab = api.vocab
    self.rev_vocab = api.rev_vocab
    self.vocab_size = len(self.vocab)
    self.sess = sess
    self.scope = scope
    self.max_utt_len = config.max_utt_len
    self.go_id = self.rev_vocab["<s>"]
    self.eos_id = self.rev_vocab["</s>"]
    self.context_cell_size = config.cxt_cell_size
    self.sent_cell_size = config.sent_cell_size
    self.dec_cell_size = config.dec_cell_size
    self.num_topics = config.num_topics

    with tf.name_scope("io"):
        # all dialog context and known attributes
        self.input_contexts = tf.placeholder(dtype=tf.int32, shape=(None, None, self.max_utt_len), name="dialog_context")
        self.floors = tf.placeholder(dtype=tf.float32, shape=(None, None), name="floor")  # TODO float
        self.floor_labels = tf.placeholder(dtype=tf.float32, shape=(None, 1), name="floor_labels")
        self.context_lens = tf.placeholder(dtype=tf.int32, shape=(None,), name="context_lens")
        self.paragraph_topics = tf.placeholder(dtype=tf.float32, shape=(None, self.num_topics), name="paragraph_topics")

        # target response given the dialog context
        self.output_tokens = tf.placeholder(dtype=tf.int32, shape=(None, None), name="output_token")
        self.output_lens = tf.placeholder(dtype=tf.int32, shape=(None,), name="output_lens")
        self.output_das = tf.placeholder(dtype=tf.float32, shape=(None, self.num_topics), name="output_dialog_acts")

        # optimization related variables
        self.learning_rate = tf.Variable(float(config.init_lr), trainable=False, name="learning_rate")
        self.learning_rate_decay_op = self.learning_rate.assign(
            tf.multiply(self.learning_rate, config.lr_decay))
        self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")

    max_dialog_len = array_ops.shape(self.input_contexts)[1]
    max_out_len = array_ops.shape(self.output_tokens)[1]
    batch_size = array_ops.shape(self.input_contexts)[0]

    with variable_scope.variable_scope("wordEmbedding"):
        self.embedding = tf.get_variable("embedding", [self.vocab_size, config.embed_size], dtype=tf.float32)
        # Mask row 0 of the embedding to zeros (id 0 is presumably padding).
        embedding_mask = tf.constant([0 if i == 0 else 1 for i in range(self.vocab_size)],
                                     dtype=tf.float32, shape=[self.vocab_size, 1])
        embedding = self.embedding * embedding_mask

        # embed the input
        input_embedding = embedding_ops.embedding_lookup(embedding, tf.reshape(self.input_contexts, [-1]))
        input_embedding = tf.reshape(input_embedding, [-1, self.max_utt_len, config.embed_size])
        # encode each utterance using a GRU
        sent_cell = self.get_rnncell("gru", self.sent_cell_size, config.keep_prob, 1)
        input_embedding, sent_size = get_rnn_encode(input_embedding, sent_cell, scope="sent_rnn")
        # regroup utterance encodings into dialogs
        input_embedding = tf.reshape(input_embedding, [-1, max_dialog_len, sent_size])
        if config.keep_prob < 1.0:
            input_embedding = tf.nn.dropout(input_embedding, config.keep_prob)

        # floor = probability that the next sentence is the last
        # TODO do we want this?
        floor = tf.reshape(self.floors, [-1, max_dialog_len, 1])
        joint_embedding = tf.concat([input_embedding, floor], 2, "joint_embedding")

    with variable_scope.variable_scope("contextRNN"):
        enc_cell = self.get_rnncell(config.cell_type, self.context_cell_size,
                                    keep_prob=1.0, num_layer=config.num_layer)
        # sequence_length makes enc_last_state the true last state of each dialog
        _, enc_last_state = tf.nn.dynamic_rnn(enc_cell, joint_embedding, dtype=tf.float32,
                                              sequence_length=self.context_lens)

        if config.num_layer > 1:
            if config.cell_type == 'lstm':
                enc_last_state = [temp.h for temp in enc_last_state]
            enc_last_state = tf.concat(enc_last_state, 1)
        else:
            if config.cell_type == 'lstm':
                enc_last_state = enc_last_state.h

    # Final output from the encoder
    encoded_list = [self.paragraph_topics, enc_last_state]
    encoded_embedding = tf.concat(encoded_list, 1)

    with variable_scope.variable_scope("generationNetwork"):
        # predict whether the next sentence is the last one
        # TODO do we want this?
        # NOTE(review): tanh bounds these logits to [-1, 1], so the sigmoid
        # probability implied by the end loss is limited to about
        # [0.27, 0.73]. Confirm this is intended.
        self.paragraph_end_logits = layers.fully_connected(
            encoded_embedding, 1, activation_fn=tf.tanh, scope="paragraph_end_fc1")

        # Decoder initial state, one projection per layer
        if config.num_layer > 1:
            dec_init_state = []
            for i in range(config.num_layer):
                temp_init = layers.fully_connected(encoded_embedding, self.dec_cell_size,
                                                   activation_fn=None, scope="init_state-%d" % i)
                if config.cell_type == 'lstm':
                    # LSTM state needs (c, h); both start from the same projection
                    temp_init = rnn_cell.LSTMStateTuple(temp_init, temp_init)
                dec_init_state.append(temp_init)
            dec_init_state = tuple(dec_init_state)
        else:
            dec_init_state = layers.fully_connected(encoded_embedding, self.dec_cell_size,
                                                    activation_fn=None, scope="init_state")
            if config.cell_type == 'lstm':
                dec_init_state = rnn_cell.LSTMStateTuple(dec_init_state, dec_init_state)

    with variable_scope.variable_scope("decoder"):
        dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                    config.keep_prob, config.num_layer)
        # projects into vocab-size logits. TODO no softmax?
        dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

        if forward:
            loop_func = decoder_fn_lib.context_decoder_fn_inference(
                None, dec_init_state, embedding,
                start_of_sequence_id=self.go_id,
                end_of_sequence_id=self.eos_id,
                maximum_length=self.max_utt_len,
                num_decoder_symbols=self.vocab_size,
                context_vector=None)
            dec_input_embedding = None
            dec_seq_lens = None
        else:
            loop_func = decoder_fn_lib.context_decoder_fn_train(dec_init_state, None)
            dec_input_embedding = embedding_ops.embedding_lookup(embedding, self.output_tokens)
            # feed every target token except the last one
            dec_input_embedding = dec_input_embedding[:, 0:-1, :]
            dec_seq_lens = self.output_lens - 1

            if config.keep_prob < 1.0:
                dec_input_embedding = tf.nn.dropout(dec_input_embedding, config.keep_prob)

            # apply word dropping. Set dropped word to 0
            if config.dec_keep_prob < 1.0:
                # get mask of keep/throw-away
                keep_mask = tf.less_equal(
                    tf.random_uniform((batch_size, max_out_len - 1), minval=0.0, maxval=1.0),
                    config.dec_keep_prob)
                keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                dec_input_embedding = dec_input_embedding * keep_mask
                dec_input_embedding = tf.reshape(
                    dec_input_embedding, [-1, max_out_len - 1, config.embed_size])

        dec_outs, _, final_context_state = dynamic_rnn_decoder(
            dec_cell, loop_func, inputs=dec_input_embedding,
            sequence_length=dec_seq_lens, name='output_node')

        if final_context_state is not None:
            final_context_state = final_context_state[:, 0:array_ops.shape(dec_outs)[1]]
            mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
            self.dec_out_words = tf.multiply(tf.reverse(final_context_state, axis=[1]), mask)
        else:
            self.dec_out_words = tf.argmax(dec_outs, 2)

    if not forward:
        with variable_scope.variable_scope("loss"):
            labels = self.output_tokens[:, 1:]  # correct word tokens
            label_mask = tf.to_float(tf.sign(labels))

            # Single-argument print(...) works identically on Python 2 and 3
            # (the old `print "..."` statement is a SyntaxError on Python 3).
            print("dec outs shape %s" % dec_outs.get_shape())
            print("labels shape %s" % labels.get_shape())

            # Loss between words
            rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=dec_outs, labels=labels)
            rc_loss = tf.reduce_sum(rc_loss * label_mask, reduction_indices=1)
            self.avg_rc_loss = tf.reduce_mean(rc_loss)
            # used only for perplexity calculation. Not used for optimization
            self.rc_ppl = tf.exp(tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))

            # Predict 0/1 (1 = last sentence in paragraph). There is a single
            # logit per example, so this is a binary problem. A softmax over
            # one class is constantly 1, which made the previous softmax loss
            # identically zero with a spurious gradient. Use the sigmoid form.
            end_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=self.floor_labels, logits=self.paragraph_end_logits)
            self.avg_end_loss = tf.reduce_mean(end_loss)
            print("size of end loss %s" % self.avg_end_loss.get_shape())

            total_loss = self.avg_rc_loss + self.avg_end_loss

            tf.summary.scalar("rc_loss", self.avg_rc_loss)
            tf.summary.scalar("paragraph_end_loss", self.avg_end_loss)

        self.summary_op = tf.summary.merge_all()
        self.optimize(sess, config, total_loss, log_dir)

    self.saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2)
def __init__(self, numUnits, usePeepholes=False, initializer=None,
             numProj=None, projClip=None, numUnitShards=1, numProShards=1,
             forgetBias=1.0, stateIsTuple=True, activation=math_ops.tanh,
             reuse=None):
    """Initialize the parameters for an LSTM cell.

    Args:
      numUnits: int, The number of units in the LSTM cell.
      usePeepholes: bool, set True to enable diagonal/peephole connections.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      numProj: (optional) int, The output dimensionality for the projection
        matrices. If None (or 0), no projection is performed.
      projClip: (optional) A float value. If `numProj > 0` and `projClip` is
        provided, then the projected values are clipped elementwise to within
        `[-projClip, projClip]`.
      numUnitShards: How to split the weight matrix. If >1, the weight
        matrix is stored across numUnitShards.
      numProShards: How to split the projection matrix. If >1, the
        projection matrix is stored across numProShards.
      forgetBias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training.
      stateIsTuple: If True (the default), accepted and returned states are
        2-tuples of the `c_state` and `m_state`. If False, they are
        concatenated along the column axis; this mode is slower and logs a
        deprecation warning.
      activation: Activation function of the inner states.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
    if not stateIsTuple:
        # `warn` is a deprecated alias; `warning` is available on both the
        # standard-library and TensorFlow logging modules.
        logging.warning(
            "%s: Using a concatenated state is slower and will soon be "
            "deprecated. Use stateIsTuple=True.", self)
    self._numUnits = numUnits
    self._usePeepholes = usePeepholes
    self._initializer = initializer
    self._numProj = numProj
    self._projClip = projClip
    self._numUnitShards = numUnitShards
    self._numProShards = numProShards
    self._forgetBias = forgetBias
    self._stateIsTuple = stateIsTuple
    self._activation = activation
    self._reuse = reuse

    # With a projection, the hidden part of the state (and the output) has
    # numProj columns instead of numUnits.
    hidden_size = numProj if numProj else numUnits
    if stateIsTuple:
        self._state_size = rnn_cell_impl.LSTMStateTuple(numUnits, hidden_size)
    else:
        self._state_size = numUnits + hidden_size
    self._output_size = hidden_size
def call(self, inputs, state):
    """Advance the NAS cell by one time step.

    Args:
      inputs: 2-D input tensor, batch x input_size.
      state: `(c_prev, m_prev)` tuple of 2-D tensors. `m_prev` has `num_proj`
        columns when a projection is configured, `num_units` otherwise.

    Returns:
      `(new_m, LSTMStateTuple(new_c, new_m))`, where `new_m` is
      `[batch x output_dim]` and output_dim is num_proj if set, num_units
      otherwise.

    Raises:
      ValueError: If the input size cannot be inferred from static shape
        information.
    """
    if self._input_layer_norm:
        inputs = tf.contrib.layers.layer_norm(scope="inputs_ln", inputs=inputs,
                                              reuse=tf.AUTO_REUSE)

    sigmoid = math_ops.sigmoid
    tanh = math_ops.tanh
    selu = tf.nn.selu

    num_units = self._num_units
    num_proj = num_units if self._num_proj is None else self._num_proj
    c_prev, m_prev = state
    dtype = inputs.dtype

    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
        raise ValueError(
            "Could not infer input size from inputs.get_shape()[-1]")

    # One fused kernel per source yields all 8 branch pre-activations at once.
    recurrent_kernel = tf.get_variable("recurrent_kernel",
                                       [num_proj, 8 * num_units], dtype)
    input_kernel = tf.get_variable("kernel",
                                   [input_size.value, 8 * num_units], dtype)
    m_matrix = math_ops.matmul(m_prev, recurrent_kernel)
    inputs_matrix = math_ops.matmul(inputs, input_kernel)

    if self._use_biases:
        bias = tf.get_variable("bias", shape=[8 * num_units],
                               initializer=init_ops.zeros_initializer(),
                               dtype=dtype)
        m_matrix = nn_ops.bias_add(m_matrix, bias)

    # Only the recurrent pre-activations are layer-normalized.
    if self._layer_norm:
        m_matrix = tf.contrib.layers.layer_norm(scope="m_matrix_ln",
                                                inputs=m_matrix,
                                                reuse=tf.AUTO_REUSE)

    m_parts = array_ops.split(axis=1, num_or_size_splits=8, value=m_matrix)
    x_parts = array_ops.split(axis=1, num_or_size_splits=8, value=inputs_matrix)

    # First layer: branch 3 multiplies its two sources, the others add them.
    branch_activations = (sigmoid, selu, sigmoid, selu,
                          tanh, sigmoid, tanh, sigmoid)
    layer1 = []
    for k, act in enumerate(branch_activations):
        if k == 3:
            combined = x_parts[k] * m_parts[k]
        else:
            combined = x_parts[k] + m_parts[k]
        layer1.append(act(combined))

    # Second layer.
    l2_0 = tanh(layer1[0] * layer1[1])
    l2_1 = tanh(layer1[2] + layer1[3])
    l2_2 = tanh(layer1[4] * layer1[5])
    l2_3 = sigmoid(layer1[6] + layer1[7])

    # Inject the previous cell state; the third-layer product is the new cell.
    new_c = tanh(l2_0 + c_prev) * l2_1

    # Final layer.
    new_m = tanh(new_c * tanh(l2_2 + l2_3))

    # Optional linear projection of the output.
    if self._num_proj is not None:
        proj_kernel = tf.get_variable("projection_weights",
                                      [num_units, self._num_proj], dtype)
        new_m = math_ops.matmul(new_m, proj_kernel)

    return new_m, rnn_cell_impl.LSTMStateTuple(new_c, new_m)
def call(self, inputs, state):
    """Advance the (optionally normalized) LSTM by one time step.

    Args:
      inputs: 2-D input tensor, batch x input_size.
      state: `(c, h)` tuple of 2-D state tensors (`c_state`, `m_state`).

    Returns:
      `(new_h, LSTMStateTuple(new_c, new_h))`, where `new_h` is
      `[batch x output_dim]` and output_dim is num_proj if set, num_units
      otherwise.

    Raises:
      ValueError: If the input size cannot be inferred from static shape
        information.
    """
    dtype = inputs.dtype
    num_units = self._num_units
    sigmoid = math_ops.sigmoid
    c, h = state

    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
        raise ValueError(
            "Could not infer input size from inputs.get_shape()[-1]")

    with vs.variable_scope(self._scope, initializer=self._initializer):
        gate_inputs = self._linear([inputs, h], 4 * num_units,
                                   norm=self._norm, bias=True)
        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = array_ops.split(value=gate_inputs, num_or_size_splits=4,
                                     axis=1)

        peepholes = self._use_peepholes
        if peepholes:
            # Diagonal (element-wise) peephole weights.
            w_f_diag = vs.get_variable("w_f_diag", shape=[num_units], dtype=dtype)
            w_i_diag = vs.get_variable("w_i_diag", shape=[num_units], dtype=dtype)
            w_o_diag = vs.get_variable("w_o_diag", shape=[num_units], dtype=dtype)
            forget_gate = sigmoid(f + self._forget_bias + w_f_diag * c)
            input_gate = sigmoid(i + w_i_diag * c)
        else:
            forget_gate = sigmoid(f + self._forget_bias)
            input_gate = sigmoid(i)
        new_c = c * forget_gate + input_gate * self._activation(j)

        if self._cell_clip is not None:
            # pylint: disable=invalid-unary-operand-type
            new_c = clip_ops.clip_by_value(new_c, -self._cell_clip,
                                           self._cell_clip)
            # pylint: enable=invalid-unary-operand-type

        # The output-gate peephole looks at the updated cell state.
        if peepholes:
            output_gate = sigmoid(o + w_o_diag * new_c)
        else:
            output_gate = sigmoid(o)
        new_h = output_gate * self._activation(new_c)

        if self._num_proj is not None:
            with vs.variable_scope("projection"):
                new_h = self._linear(new_h, self._num_proj, norm=self._norm,
                                     bias=False)
            if self._proj_clip is not None:
                # pylint: disable=invalid-unary-operand-type
                new_h = clip_ops.clip_by_value(new_h, -self._proj_clip,
                                               self._proj_clip)
                # pylint: enable=invalid-unary-operand-type

    return new_h, rnn_cell_impl.LSTMStateTuple(new_c, new_h)