コード例 #1
0
 def import_ops(self):
     """Reattach this model's op handles from graph collections.

     Training models also recover the train op and learning-rate tensors, and
     when both a cell and a non-empty "rnn_params" collection exist they wrap
     the opaque RNN parameters in an ``RNNParamsSaveable`` registered under
     ``tf.GraphKeys.SAVEABLE_OBJECTS``. Every model recovers its cost tensor
     and its initial/final state tuples.
     """
     def _first(collection_key):
         # Only the leading entry of each collection is used as the handle.
         return tf.get_collection_ref(collection_key)[0]

     if self._is_training:
         self._train_op = _first("train_op")
         self._lr = _first("lr")
         self._new_lr = _first("new_lr")
         self._lr_update = _first("lr_update")
         opaque_params = tf.get_collection_ref("rnn_params")
         if self._cell and opaque_params:
             saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
                 self._cell,
                 self._cell.params_to_canonical,
                 self._cell.canonical_to_params,
                 opaque_params,
                 base_variable_scope="Model/RNN")
             tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

     self._cost = _first(util.with_prefix(self._name, "cost"))
     # Only the "Train" model is replicated (FLAGS.num_gpus copies).
     replicas = FLAGS.num_gpus if self._name == "Train" else 1
     self._initial_state = util.import_state_tuples(
         self._initial_state, self._initial_state_name, replicas)
     self._final_state = util.import_state_tuples(
         self._final_state, self._final_state_name, replicas)
コード例 #2
0
 def import_ops(self, config):
     """Reattach op handles from graph collections.

     Args:
       config: model configuration; ``config.num_layers`` and
         ``config.hidden_size`` describe the CuDNN LSTM whose opaque
         parameters are wrapped in a ``CudnnLSTMSaveable`` when training.
     """
     if self._is_training:
         for attr_name, key in (("_train_op", "train_op"),
                                ("_lr", "lr"),
                                ("_new_lr", "new_lr"),
                                ("_lr_update", "lr_update")):
             setattr(self, attr_name, tf.get_collection_ref(key)[0])
         opaque = tf.get_collection_ref("rnn_params")
         if self._cell and opaque:
             # hidden_size fills both the 3rd and 4th positional slots
             # (presumably num_units and input_size), i.e. the LSTM input
             # width is taken to equal the hidden width.
             lstm_saveable = tf.contrib.cudnn_rnn.CudnnLSTMSaveable(
                 opaque,
                 config.num_layers,
                 config.hidden_size,
                 config.hidden_size,
                 scope="Model/RNN")
             tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS,
                                  lstm_saveable)

     cost_key = util.with_prefix(self._name, "cost")
     self._cost = tf.get_collection_ref(cost_key)[0]
     if self._name == "Train":
         replicas = FLAGS.num_gpus
     else:
         replicas = 1
     self._initial_state = util.import_state_tuples(
         self._initial_state, self._initial_state_name, replicas)
     self._final_state = util.import_state_tuples(
         self._final_state, self._final_state_name, replicas)
コード例 #3
0
    def import_ops(self, num_gpus=1):
        """Reattach op handles from graph collections.

        Args:
          num_gpus: replica count used when this is the "Train" model;
            every other model imports a single replica of its state.
        """
        if self._is_training:
            for attr_name, key in (("_train_op", "train_op"),
                                   ("_lr", "lr"),
                                   ("_new_lr", "new_lr"),
                                   ("_lr_update", "lr_update")):
                setattr(self, attr_name, tf.get_collection_ref(key)[0])

        def scoped(suffix):
            # Per-model tensors are stored under a name-prefixed key.
            return tf.get_collection_ref(
                util.with_prefix(self._name, suffix))[0]

        self._cost = scoped("cost")
        self._kl_loss = scoped("kl_div")
        self._input_data = scoped("input_data")
        self._output = scoped("output")
        self._targets = scoped("targets")

        replicas = num_gpus if self._name == "Train" else 1
        self._initial_state = util.import_state_tuples(
            self._initial_state, self._initial_state_name, replicas)
        self._final_state = util.import_state_tuples(
            self._final_state, self._final_state_name, replicas)
コード例 #4
0
 def import_ops(self):
     """Reattach op handles from graph collections.

     Training models recover the train op, learning-rate tensors and the
     ``params_size`` tensor, plus ``memory_use`` when ``FLAGS.num_gpus`` is
     truthy. Every model recovers its cost tensor and initial/final state
     tuples.
     """
     def _head(key):
         return tf.get_collection_ref(key)[0]

     if self._is_training:
         for attr_name, key in (("_train_op", "train_op"),
                                ("_lr", "lr"),
                                ("_new_lr", "new_lr"),
                                ("_lr_update", "lr_update"),
                                ("_params_size", "params_size")):
             setattr(self, attr_name, _head(key))
         if FLAGS.num_gpus:
             self._memory_use = _head("memory_use")
         # NOTE: the "rnn_params" collection is still looked up (this creates
         # it as an empty list if absent). However, no CuDNN params saveable
         # is registered in SAVEABLE_OBJECTS in this version, so the opaque
         # RNN params are not checkpointed through this path.
         tf.get_collection_ref("rnn_params")

     self._cost = _head(util.with_prefix(self._name, "cost"))
     replicas = FLAGS.num_gpus if self._name == "Train" else 1
     self._initial_state = util.import_state_tuples(
         self._initial_state, self._initial_state_name, replicas)
     self._final_state = util.import_state_tuples(
         self._final_state, self._final_state_name, replicas)
コード例 #5
0
ファイル: Model.py プロジェクト: bnuside/emoji_recom
 def import_ops(self):
     """Reattach op handles from graph collections.

     Training models recover the train op and learning-rate tensors, and may
     register an ``RNNParamsSaveable`` for the opaque RNN params. Models with
     ``model_type == 'test'`` also recover ``logits`` and ``y``. Every model
     recovers its cost tensor and initial/final state tuples; the 'Train'
     model uses ``self.num_gpus`` replicas.
     """
     def _head(key):
         return tf.get_collection_ref(key)[0]

     if self._is_training:
         self._train_op = _head('train_op')
         self._lr = _head('lr')
         self._new_lr = _head('new_lr')
         self._lr_update = _head('lr_update')
         opaque_params = tf.get_collection_ref('rnn_params')
         if self._cell and opaque_params:
             saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
                 self._cell,
                 self._cell.params_to_canonical,
                 self._cell.canonical_to_params,
                 opaque_params,
                 base_variable_scope='Model/RNN')
             tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

     if self.model_type == 'test':
         self.logits = _head('logits')
         self.y = _head('y')

     self._cost = _head(util.with_prefix(self._name, 'cost'))
     if self._name == 'Train':
         replicas = self.num_gpus
     else:
         replicas = 1
     self._initial_state = util.import_state_tuples(
         self._initial_state, self._initial_state_name, replicas)
     self._final_state = util.import_state_tuples(
         self._final_state, self._final_state_name, replicas)
コード例 #6
0
ファイル: ptb_word_lm.py プロジェクト: RM1708/PTB
 def import_ops(self):
     """Reattach op handles from graph collections.

     Training models recover the train op and learning-rate tensors. When a
     cell and a non-empty "rnn_params" collection exist, the opaque RNN params
     are wrapped in an ``RNNParamsSaveable`` and registered under
     ``tf.GraphKeys.SAVEABLE_OBJECTS``. Every model recovers its cost tensor
     and initial/final state tuples.
     """
     if self._is_training:
         self._train_op = tf.get_collection_ref("train_op")[0]
         self._lr = tf.get_collection_ref("lr")[0]
         self._new_lr = tf.get_collection_ref("new_lr")[0]
         self._lr_update = tf.get_collection_ref("lr_update")[0]
         rnn_params = tf.get_collection_ref("rnn_params")
         if self._cell and rnn_params:
             # RNNParamsSaveable is not reachable as
             # tf.contrib.cudnn_rnn.RNNParamsSaveable in the TF build this was
             # written against. It is taken from a `cudnn_rnn` module instead,
             # which per the author's note provides it
             # (tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py).
             # NOTE(review): `cudnn_rnn` must be imported at module level —
             # not visible here; confirm.
             params_saveable = cudnn_rnn.RNNParamsSaveable(
                 self._cell,
                 self._cell.params_to_canonical,
                 self._cell.canonical_to_params,
                 rnn_params,
                 base_variable_scope="Model/RNN")
             tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS,
                                  params_saveable)
     self._cost = tf.get_collection_ref(util.with_prefix(
         self._name, "cost"))[0]
     num_replicas = FLAGS.num_gpus if self._name == "Train" else 1
     self._initial_state = util.import_state_tuples(
         self._initial_state, self._initial_state_name, num_replicas)
     self._final_state = util.import_state_tuples(self._final_state,
                                                  self._final_state_name,
                                                  num_replicas)
コード例 #7
0
ファイル: ptb_word_lm.py プロジェクト: zhangdong86/blog-codes
 def import_ops(self):
     """Reattach op handles from graph collections.

     Training models recover the train op and learning-rate tensors. When a
     cell and a non-empty "rnn_params" collection exist, the opaque CuDNN
     params are wrapped in a ``CudnnLSTMSaveable`` (shape taken from the
     cell) and registered under ``tf.GraphKeys.SAVEABLE_OBJECTS``. Every
     model recovers its cost tensor and initial/final state tuples.
     """
     if self._is_training:
         self._train_op = tf.get_collection_ref("train_op")[0]
         self._lr = tf.get_collection_ref("lr")[0]
         self._new_lr = tf.get_collection_ref("new_lr")[0]
         self._lr_update = tf.get_collection_ref("lr_update")[0]
         rnn_params = tf.get_collection_ref("rnn_params")
         if self._cell and rnn_params:
             params_saveable = tf.contrib.cudnn_rnn.CudnnLSTMSaveable(
                 rnn_params,
                 self._cell.num_layers,
                 self._cell.num_units,
                 self._cell.input_size,
                 self._cell.input_mode,
                 self._cell.direction,
                 scope="Model/RNN")
             # Constructing the saveable does not register it. It must be
             # added to SAVEABLE_OBJECTS so a Saver can checkpoint and restore
             # the opaque CuDNN parameter buffer.
             tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS,
                                  params_saveable)
     self._cost = tf.get_collection_ref(util.with_prefix(
         self._name, "cost"))[0]
     num_replicas = FLAGS.num_gpus if self._name == "Train" else 1
     self._initial_state = util.import_state_tuples(
         self._initial_state, self._initial_state_name, num_replicas)
     self._final_state = util.import_state_tuples(self._final_state,
                                                  self._final_state_name,
                                                  num_replicas)
コード例 #8
0
ファイル: model.py プロジェクト: davletov-aa/morph_analysis
  def import_ops(self):
    """Reattach op handles from graph collections for this model.

    Name-prefixed ``clfs_padd`` and ``last_step`` tensors are always
    restored. Training models also restore learning-rate ops, ``new_val``,
    ``update_epoch`` and every op keyed in ``self._exported_ops['train_ops']``.
    Epoch/step counters, losses, L2 losses and predictions are restored from
    the keys recorded in ``self._exported_ops``.
    """
    def first(key):
      return tf.get_collection_ref(key)[0]

    def short_name(key):
      # The second '/'-separated component of the key is used as the dict key.
      return key.split('/')[1]

    self._padding = first(util.with_prefix(self._name, 'clfs_padd'))
    self._last_step = first(util.with_prefix(self._name, 'last_step'))
    if self._is_training:
      self._lr = first("lr")
      self._new_lr = first("new_lr")
      self._lr_update = first("lr_update")
      self._new_val = first("new_val")
      # NOTE(review): unlike every other lookup here, this keeps the whole
      # collection list rather than its first element — confirm intended.
      self._update_epoch = tf.get_collection_ref("update_epoch")

      for op_key in self._exported_ops['train_ops']:
        self._train_ops[op_key] = first(op_key)

    cur_epoch_key = f'{self._name}/cur_epoch'
    for op_key in self._exported_ops['epoch_and_step']:
      if op_key != cur_epoch_key:
        self._global_step = first(op_key)
      else:
        self._cur_epoch = first(op_key)
    for op_key in self._exported_ops['losses']:
      self._losses[short_name(op_key)] = first(op_key)
    for op_key in self._exported_ops['l2_losses']:
      self._l2_losses[short_name(op_key)] = first(op_key)
    for op_key in self._exported_ops['predictions']:
      self._predictions[short_name(op_key)] = first(op_key)

    replicas = FLAGS.num_gpus if self._name == "Train" else 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, replicas)
コード例 #9
0
ファイル: emoji_lm.py プロジェクト: hougrammer/emoji_project
 def import_ops(self):
     """Reattach op handles from graph collections.

     Training models recover the train op and learning-rate tensors. Every
     model recovers its cost tensor and final state.

     NOTE(review): only the final state is restored, and
     ``util.import_state_tuples`` is called without a replica count. Confirm
     that this project's ``util.import_state_tuples`` accepts two arguments.
     """
     if self._is_training:
         for attr_name, key in (("_train_op", "train_op"),
                                ("_lr", "lr"),
                                ("_new_lr", "new_lr"),
                                ("_lr_update", "lr_update")):
             setattr(self, attr_name, tf.get_collection_ref(key)[0])
     cost_key = util.with_prefix(self._name, "cost")
     self._cost = tf.get_collection_ref(cost_key)[0]
     self._final_state = util.import_state_tuples(
         self._final_state, self._final_state_name)
コード例 #10
0
 def import_ops(self, num_gpus = 1):
     """Reattach op handles from graph collections.

     Args:
       num_gpus: replica count used when this is the "Train" model; every
         other model imports a single replica of its state.
     """
     def _head(key):
         return tf.get_collection_ref(key)[0]

     if self._is_training:
         self._train_op = _head("train_op")
         self._lr = _head("lr")
         self._new_lr = _head("new_lr")
         self._lr_update = _head("lr_update")

     # Per-model tensors are stored under name-prefixed collection keys.
     for attr_name, suffix in (("_cost", "cost"),
                               ("_kl_loss", "kl_div"),
                               ("_input_data", "input_data"),
                               ("_output", "output"),
                               ("_targets", "targets")):
         setattr(self, attr_name, _head(util.with_prefix(self._name, suffix)))

     if self._name == "Train":
         replicas = num_gpus
     else:
         replicas = 1
     self._initial_state = util.import_state_tuples(
         self._initial_state, self._initial_state_name, replicas)
     self._final_state = util.import_state_tuples(
         self._final_state, self._final_state_name, replicas)
コード例 #11
0
ファイル: ptb_word_lm.py プロジェクト: rwightman/models
 def import_ops(self):
   """Reattach op handles from graph collections.

   Training models recover the train op and learning-rate tensors and, when
   a cell and a non-empty "rnn_params" collection exist, register an
   ``RNNParamsSaveable`` for the opaque RNN params under
   ``tf.GraphKeys.SAVEABLE_OBJECTS``. Every model recovers its cost tensor and
   initial/final state tuples.
   """
   def _one(key):
     return tf.get_collection_ref(key)[0]

   if self._is_training:
     for attr_name, key in (("_train_op", "train_op"),
                            ("_lr", "lr"),
                            ("_new_lr", "new_lr"),
                            ("_lr_update", "lr_update")):
       setattr(self, attr_name, _one(key))
     opaque_params = tf.get_collection_ref("rnn_params")
     if self._cell and opaque_params:
       saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
           self._cell,
           self._cell.params_to_canonical,
           self._cell.canonical_to_params,
           opaque_params,
           base_variable_scope="Model/RNN")
       tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

   self._cost = _one(util.with_prefix(self._name, "cost"))
   replicas = FLAGS.num_gpus if self._name == "Train" else 1
   self._initial_state = util.import_state_tuples(
       self._initial_state, self._initial_state_name, replicas)
   self._final_state = util.import_state_tuples(
       self._final_state, self._final_state_name, replicas)
コード例 #12
0
from __future__ import absolute_import