Example #1
def single_cell(cell_class,
                cell_params,
                dp_input_keep_prob=1.0,
                dp_output_keep_prob=1.0,
                residual_connections=False):
    """Creates an instance of the rnn cell.
     Such cell describes one step one layer and can include residual connection
     and/or dropout

     Args:
       cell_class: Tensorflow RNN cell class
       cell_params (dict): cell parameters
       dp_input_keep_prob (float): (default: 1.0) input dropout keep probability
       dp_output_keep_prob (float): (default: 1.0) output dropout keep probability
       residual_connections (bool): whether to add a residual connection

     Returns:
       TF RNN instance
    """
    cell = cell_class(**cell_params)
    if residual_connections:
        cell = ResidualWrapper(cell)
    if dp_input_keep_prob != 1.0 or dp_output_keep_prob != 1.0:
        cell = DropoutWrapper(cell,
                              input_keep_prob=dp_input_keep_prob,
                              output_keep_prob=dp_output_keep_prob)
    return cell
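
A minimal usage sketch (not part of the original example), assuming a TF 1.x environment where LSTMCell and MultiRNNCell come from tf.contrib.rnn and the wrappers above are already imported: several such cells can be stacked into a multi-layer RNN.

from tensorflow.contrib.rnn import LSTMCell, MultiRNNCell

layers = [
    single_cell(cell_class=LSTMCell,
                cell_params={"num_units": 256},   # hypothetical size
                dp_output_keep_prob=0.8,
                residual_connections=(i > 0))     # residual only above layer 0
    for i in range(3)
]
multi_cell = MultiRNNCell(layers)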
Example #2
def single_cell(
    cell_class,
    cell_params,
    dp_input_keep_prob=1.0,
    dp_output_keep_prob=1.0,
    recurrent_keep_prob=1.0,
    input_weight_keep_prob=1.0,
    recurrent_weight_keep_prob=1.0,
    weight_variational=False,
    dropout_seed=None,
    zoneout_prob=0.0,
    training=True,
    residual_connections=False,
    awd_initializer=False,
    variational_recurrent=False,  # in case they want to use DropoutWrapper
    dtype=None,
):
    """Creates an instance of the rnn cell.
     Such cell describes one step one layer and can include residual connection
     and/or dropout
     Args:
      cell_class: Tensorflow RNN cell class
      cell_params (dict): cell parameters
      dp_input_keep_prob (float): (default: 1.0) input dropout keep
        probability.
      dp_output_keep_prob (float): (default: 1.0) output dropout keep
        probability.
      zoneout_prob(float): zoneout probability. Applying both zoneout and
        droupout is currently not supported
      residual_connections (bool): whether to add residual connection
     Returns:
       TF RNN instance
  """
    if awd_initializer:
        val = 1.0 / math.sqrt(cell_params['num_units'])
        cell_params['initializer'] = tf.random_uniform_initializer(minval=-val,
                                                                   maxval=val)

    cell = cell_class(**cell_params)
    if residual_connections:
        cell = ResidualWrapper(cell)
    if zoneout_prob > 0.0 and (dp_input_keep_prob < 1.0
                               or dp_output_keep_prob < 1.0):
        raise ValueError(
            'Applying both dropout and zoneout on the same cell '
            'is currently not supported.')
    if (dp_input_keep_prob != 1.0 or dp_output_keep_prob != 1.0) and training:
        cell = DropoutWrapper(
            cell,
            input_keep_prob=dp_input_keep_prob,
            output_keep_prob=dp_output_keep_prob,
            variational_recurrent=variational_recurrent,
            dtype=dtype,
            seed=dropout_seed,
        )
    if zoneout_prob > 0.0:
        cell = ZoneoutWrapper(cell, zoneout_prob, is_training=training)
    return cell
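
A hypothetical call sketch, assuming LSTMCell comes from tf.contrib.rnn (TF 1.x): the guard above rejects combining zoneout with input/output dropout on the same cell.

from tensorflow.contrib.rnn import LSTMCell

try:
    single_cell(LSTMCell, {'num_units': 256},      # hypothetical parameters
                dp_output_keep_prob=0.8, zoneout_prob=0.1)
except ValueError as err:
    print(err)  # dropout and zoneout cannot be mixed on one cell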
Example #3
    def _create_single_rnn_cell(self, num_units):
        cell = (GRUCell(num_units) if self.cfg["cell_type"] == "gru"
                else LSTMCell(num_units))
        if self.cfg["use_dropout"]:
            cell = DropoutWrapper(cell, output_keep_prob=self.rnn_keep_prob)
        if self.cfg["use_residual"]:
            cell = ResidualWrapper(cell)
        return cell

    def _create_rnn_cell(self):
        cell = (GRUCell(self.cfg.num_units) if self.cfg.cell_type == "gru"
                else LSTMCell(self.cfg.num_units))
        if self.cfg.use_dropout:
            cell = DropoutWrapper(cell, output_keep_prob=self.keep_prob)
        if self.cfg.use_residual:
            cell = ResidualWrapper(cell)
        return cell
Example #5
def setup_cell(cell_type, size, use_residual=False, keep_prob=None):
  cell = getattr(tf.contrib.rnn, cell_type)(size)
  if keep_prob is not None:
    cell = DropoutWrapper(cell, 
                          input_keep_prob=1.0,
                          output_keep_prob=keep_prob)
  if use_residual:
    cell = ResidualWrapper(cell)
  return cell
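
Usage sketch with hypothetical values: the cell_type string is resolved with getattr, so it must match a class name exported by tf.contrib.rnn in TF 1.x, e.g. "GRUCell" or "LSTMCell".

gru_cell = setup_cell("GRUCell", 128, use_residual=True, keep_prob=0.8)
lstm_cell = setup_cell("LSTMCell", 128)  # no dropout, no residual connection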
Example #6
  def single_cell(cell_params):
    # TODO: This method is ugly - redo
    size = cell_params["num_units"]
    if cell_type == "lstm":
      cell_class = LSTMCell
    elif cell_type == "gru":
      cell_class = GRUCell
    elif cell_type == "glstm":
      cell_class = GLSTMCell
      num_groups = 4  # cell_params["num_groups"]

    if residual_connections:
      if dp_input_keep_prob != 1.0 or dp_output_keep_prob != 1.0:
        if cell_type != "glstm":
          return DropoutWrapper(ResidualWrapper(cell_class(num_units=size)),
                                input_keep_prob=dp_input_keep_prob,
                                output_keep_prob=dp_output_keep_prob)
        else:
          return DropoutWrapper(
              ResidualWrapper(cell_class(num_units=size,
                                         number_of_groups=num_groups)),
              input_keep_prob=dp_input_keep_prob,
              output_keep_prob=dp_output_keep_prob)
      else:
        if cell_type != "glstm":
          return ResidualWrapper(cell_class(num_units=size))
        else:
          return ResidualWrapper(cell_class(num_units=size, number_of_groups=num_groups))
    else:
      if dp_input_keep_prob != 1.0 or dp_output_keep_prob != 1.0:
        if cell_type != "glstm":
          return DropoutWrapper(cell_class(num_units=size),
                              input_keep_prob=dp_input_keep_prob,
                              output_keep_prob=dp_output_keep_prob)
        else:
          return DropoutWrapper(cell_class(num_units=size, number_of_groups=num_groups),
                                input_keep_prob=dp_input_keep_prob,
                                output_keep_prob=dp_output_keep_prob)
      else:
        if cell_type != "glstm":
          return cell_class(num_units=size)
        else:
          return cell_class(num_units=size, number_of_groups=num_groups)
Example #7
def build_decoder_cell(rank, u_emb, batch_size, depth=2):
  cell = []
  for i in range(depth):
    if i == 0:
      cell.append(LSTMCell(rank, state_is_tuple=True))
    else:
      cell.append(ResidualWrapper(LSTMCell(rank, state_is_tuple=True)))
  initial_state = LSTMStateTuple(tf.zeros_like(u_emb), u_emb)
  initial_state = [initial_state, ]
  for i in range(1, depth):
    initial_state.append(cell[i].zero_state(batch_size, tf.float32))
  return MultiRNNCell(cell), tuple(initial_state)
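
A wiring sketch with dummy shapes (not from the original source), assuming tensorflow is imported as tf: the first layer's LSTM state is seeded from u_emb while the residual layers start from zero states, and the returned cell and initial state plug directly into dynamic_rnn.

rank, batch_size, max_len = 64, 32, 10                 # hypothetical sizes
u_emb = tf.zeros([batch_size, rank])                   # e.g. a user embedding
decoder_inputs = tf.zeros([batch_size, max_len, rank])
dec_cell, dec_init = build_decoder_cell(rank, u_emb, batch_size, depth=2)
outputs, final_state = tf.nn.dynamic_rnn(dec_cell, decoder_inputs,
                                         initial_state=dec_init)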
Example #8
    def single_cell(num_units):
        if cell_type == "lstm":
            cell_class = LSTMCell
        elif cell_type == "gru":
            cell_class = GRUCell

        if residual_connections:
            if dp_input_keep_prob != 1.0 or dp_output_keep_prob != 1.0:
                return DropoutWrapper(ResidualWrapper(
                    cell_class(num_units=num_units)),
                                      input_keep_prob=dp_input_keep_prob,
                                      output_keep_prob=dp_output_keep_prob)
            else:
                return ResidualWrapper(cell_class(num_units=num_units))
        else:
            if dp_input_keep_prob != 1.0 or dp_output_keep_prob != 1.0:
                return DropoutWrapper(cell_class(num_units=num_units),
                                      input_keep_prob=dp_input_keep_prob,
                                      output_keep_prob=dp_output_keep_prob)
            else:
                return cell_class(num_units=num_units)
Example #9
    def build_single_cell(self):
        cell_type = LSTMCell
        if (self.cell_type.lower() == 'gru'):
            cell_type = GRUCell
        cell = cell_type(self.hidden_units)

        if self.use_dropout:
            cell = DropoutWrapper(cell, dtype=self.dtype,
                                  output_keep_prob=self.keep_prob_placeholder,)
        if self.use_residual:
            cell = ResidualWrapper(cell)
            
        return cell
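
A short note on the placeholder-driven keep probability (sketch with hypothetical names, TF 1.x): feeding the keep probability through a placeholder lets the same graph apply dropout during training and disable it at inference time without rebuilding the cell.

keep_prob_placeholder = tf.placeholder_with_default(1.0, shape=[], name='keep_prob')
# training step: enable dropout by feeding a value below 1.0
#   sess.run(train_op, feed_dict={keep_prob_placeholder: 0.8})
# inference: nothing is fed, so the default of 1.0 disables dropout
#   sess.run(predictions)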
Example #10
    def build_single_cell(self, hidden_size=None):
        if (self.p.cell_type.lower() == 'gru'): cell_type = GRUCell
        else: cell_type = LSTMCell

        if not hidden_size: hidden_size = self.p.hidden_size
        cell = cell_type(hidden_size)

        if self.p.use_dropout:
            cell = DropoutWrapper(
                cell, dtype=self.p.dtype,
                output_keep_prob=self.p.keep_prob)  #change this
        if self.p.use_residual: cell = ResidualWrapper(cell)

        return cell
Example #11
    def _build_single_cell(self):
        if self.cell_type == 'lstm':
            cell_type = LSTMCell
        elif self.cell_type == 'gru':
            cell_type = GRUCell
        cell = cell_type(self.hidden_dim)
        if self.use_dropout:
            cell = DropoutWrapper(
                cell,
                dtype=self.dtype,
                output_keep_prob=self.keep_prob_placeholder,
            )
        if self.use_residual:
            cell = ResidualWrapper(cell)
        return cell
Example #12
def single_cell(cell_class,
                cell_params,
                dp_input_keep_prob=1.0,
                dp_output_keep_prob=1.0,
                zoneout_prob=0.,
                training=True,
                residual_connections=False):
    """Creates an instance of the rnn cell.
     Such cell describes one step one layer and can include residual connection
     and/or dropout

     Args:
      cell_class: Tensorflow RNN cell class
      cell_params (dict): cell parameters
      dp_input_keep_prob (float): (default: 1.0) input dropout keep
        probability.
      dp_output_keep_prob (float): (default: 1.0) output dropout keep
        probability.
      zoneout_prob(float): zoneout probability. Applying both zoneout and
        droupout is currently not supported
      residual_connections (bool): whether to add residual connection

     Returns:
       TF RNN instance
  """
    cell = cell_class(**cell_params)
    if residual_connections:
        cell = ResidualWrapper(cell)
    if zoneout_prob > 0. and (dp_input_keep_prob < 1.0
                              or dp_output_keep_prob < 1.0):
        raise ValueError(
            "Applying both dropout and zoneout on the same cell "
            "is currently not recommended.")
    if dp_input_keep_prob != 1.0 or dp_output_keep_prob != 1.0:
        cell = DropoutWrapper(cell,
                              input_keep_prob=dp_input_keep_prob,
                              output_keep_prob=dp_output_keep_prob)
    if zoneout_prob > 0.:
        cell = ZoneoutWrapper(cell, zoneout_prob, is_training=training)
    return cell
Example #13
    def single_cell(cell_params):
        # TODO: This method is ugly - redo
        size = cell_params["num_units"]
        proj_size = None if "proj_size" not in cell_params else cell_params[
            "proj_size"]

        if cell_type == "lstm":
            if not residual_connections:
                if dp_input_keep_prob == 1.0 and dp_output_keep_prob == 1.0:
                    return tf.nn.rnn_cell.LSTMCell(num_units=size,
                                                   num_proj=proj_size,
                                                   forget_bias=1.0)
                else:
                    return DropoutWrapper(tf.nn.rnn_cell.LSTMCell(
                        num_units=size, num_proj=proj_size, forget_bias=1.0),
                                          input_keep_prob=dp_input_keep_prob,
                                          output_keep_prob=dp_output_keep_prob)
            else:  # residual connection required
                if dp_input_keep_prob == 1.0 and dp_output_keep_prob == 1.0:
                    return ResidualWrapper(
                        tf.nn.rnn_cell.LSTMCell(num_units=size,
                                                num_proj=proj_size,
                                                forget_bias=1.0))
                else:
                    return ResidualWrapper(
                        DropoutWrapper(
                            tf.nn.rnn_cell.LSTMCell(
                                num_units=size,
                                num_proj=proj_size,
                                forget_bias=1.0,
                            ),
                            input_keep_prob=dp_input_keep_prob,
                            output_keep_prob=dp_output_keep_prob,
                        ))
        elif cell_type == "gru":
            if not residual_connections:
                if dp_input_keep_prob == 1.0 and dp_output_keep_prob == 1.0:
                    return tf.nn.rnn_cell.GRUCell(num_units=size)
                else:
                    return DropoutWrapper(
                        tf.nn.rnn_cell.GRUCell(num_units=size),
                        input_keep_prob=dp_input_keep_prob,
                        output_keep_prob=dp_output_keep_prob,
                    )
            else:  # residual connection required
                if dp_input_keep_prob == 1.0 and dp_output_keep_prob == 1.0:
                    return ResidualWrapper(
                        tf.nn.rnn_cell.GRUCell(num_units=size))
                else:
                    return ResidualWrapper(
                        DropoutWrapper(tf.nn.rnn_cell.GRUCell(num_units=size),
                                       input_keep_prob=dp_input_keep_prob,
                                       output_keep_prob=dp_output_keep_prob))
        elif cell_type == "glstm":
            num_groups = cell_params["num_groups"]
            if not residual_connections:
                if dp_input_keep_prob == 1.0 and dp_output_keep_prob == 1.0:
                    return GLSTMCell(num_units=size,
                                     number_of_groups=num_groups,
                                     num_proj=proj_size,
                                     forget_bias=1.0)
                else:
                    return DropoutWrapper(GLSTMCell(
                        num_units=size,
                        number_of_groups=num_groups,
                        num_proj=proj_size,
                        forget_bias=1.0),
                                          input_keep_prob=dp_input_keep_prob,
                                          output_keep_prob=dp_output_keep_prob)
            else:  # residual connection required
                if dp_input_keep_prob == 1.0 and dp_output_keep_prob == 1.0:
                    return ResidualWrapper(
                        GLSTMCell(num_units=size,
                                  number_of_groups=num_groups,
                                  num_proj=proj_size,
                                  forget_bias=1.0))
                else:
                    return ResidualWrapper(
                        DropoutWrapper(
                            GLSTMCell(
                                num_units=size,
                                number_of_groups=num_groups,
                                num_proj=proj_size,
                                forget_bias=1.0,
                            ),
                            input_keep_prob=dp_input_keep_prob,
                            output_keep_prob=dp_output_keep_prob,
                        ))
        elif cell_type == "slstm":
            if not residual_connections:
                if dp_input_keep_prob == 1.0 and dp_output_keep_prob == 1.0:
                    return BasicSLSTMCell(num_units=size)
                else:
                    return DropoutWrapper(BasicSLSTMCell(num_units=size),
                                          input_keep_prob=dp_input_keep_prob,
                                          output_keep_prob=dp_output_keep_prob)
            else:  # residual connection required
                if dp_input_keep_prob == 1.0 and dp_output_keep_prob == 1.0:
                    return ResidualWrapper(BasicSLSTMCell(num_units=size))
                else:
                    return ResidualWrapper(
                        DropoutWrapper(
                            BasicSLSTMCell(num_units=size),
                            input_keep_prob=dp_input_keep_prob,
                            output_keep_prob=dp_output_keep_prob,
                        ))
        else:
            raise ValueError("Unknown RNN cell class: {}".format(cell_type))
Example #14
# x_data = np.array([[[1,0,0,0]]],dtype = np.float32)
#
# outputs, _state = tf.nn.dynamic_rnn(cell1, x_data, dtype = tf.float32)
import tensorflow as tf

# Assumed setup (hypothetical values, not part of the original snippet):
hidden_size = 128
X_data = tf.placeholder(tf.float32, [None, None, hidden_size])
sess = tf.Session()

n_layers = 3
cells = []
for i in range(n_layers):
    cell = tf.contrib.rnn.LSTMCell(hidden_size, state_is_tuple=True)

    cell = tf.contrib.rnn.AttentionCellWrapper(cell,
                                               attn_length=40,
                                               state_is_tuple=True)

    cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=0.5)
    cell = tf.contrib.rnn.ResidualWrapper(cell)
    cells.append(cell)

cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)


def build_encoder():
    print("building encoder..")
    with tf.variable_scope('encoder'):

        outputs, last_state = tf.nn.dynamic_rnn(cell=cell,
                                                inputs=X_data,
                                                dtype=tf.float32)


build_encoder()

# Variables are created inside dynamic_rnn, so initialize them only after the
# graph has been built.
sess.run(tf.global_variables_initializer())