コード例 #1
0
 def testClassifierGraph(self):
     """Build the classifier training graph and run one training step.

     Also checks that the model exposes the expected pretrained variables:
     one embedding variable plus two variables per LSTM layer.
     """
     # Pin the layer count so the expected pretrained-variable count below
     # is deterministic.
     FLAGS.rnn_num_layers = 2
     model = graphs.VatxtModel()
     train_op, _, _ = model.classifier_training()
     # Pretrained vars: embedding + LSTM layers
     self.assertEqual(len(model.pretrained_variables),
                      1 + 2 * FLAGS.rnn_num_layers)
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         # The graph's inputs are fed by queue runners. Manage their threads
         # with a Coordinator so they are stopped and joined here, instead of
         # being left running (and erroring) after the step completes and
         # the session goes away.
         coord = tf.train.Coordinator()
         threads = tf.train.start_queue_runners(sess=sess, coord=coord)
         try:
             sess.run(train_op)
         finally:
             coord.request_stop()
             coord.join(threads)
コード例 #2
0
    def testATMethods(self):
        """Check that every adversarial-training mode reuses the model variables.

        The classifier graph is built in a fresh ``tf.Graph`` for each
        supported ``adv_training_method``. The number of trainable variables
        must equal the same fixed count for every method, so the adversarial
        branches must reuse existing variables rather than create new ones.
        """
        for adv_method in (None, 'rp', 'at', 'vat', 'atvat'):
            FLAGS.adv_training_method = adv_method
            with tf.Graph().as_default():
                graphs.VatxtModel().classifier_graph()

                # Expected count: embedding (1) + two per LSTM layer
                # + two per hidden classifier layer + logits layer (2).
                embedding_vars = 1
                lstm_vars = 2 * FLAGS.rnn_num_layers
                hidden_vars = 2 * FLAGS.cl_num_layers
                logits_vars = 2
                self.assertEqual(
                    len(tf.trainable_variables()),
                    embedding_vars + lstm_vars + hidden_vars + logits_vars)
コード例 #3
0
 def testEvalGraph(self):
     """Smoke-test evaluation graph construction.

     The unpacking below also requires ``eval_graph()`` to return exactly
     two values.
     """
     model = graphs.VatxtModel()
     _, _ = model.eval_graph()
コード例 #4
0
 def testSeqAE(self):
     """Build the language-model training graph with the seq2seq autoencoder on."""
     # NOTE(review): FLAGS is global state; presumably reset by a setUp that
     # is not shown here — confirm.
     FLAGS.use_seq2seq_autoencoder = True
     model = graphs.VatxtModel()
     model.language_model_training()
コード例 #5
0
 def testCandidateSampling(self):
     """Build the language-model training graph with num_candidate_samples=10."""
     # NOTE(review): FLAGS is global state; presumably reset by a setUp that
     # is not shown here — confirm.
     FLAGS.num_candidate_samples = 10
     model = graphs.VatxtModel()
     model.language_model_training()
コード例 #6
0
 def testSyncReplicas(self):
     """Build the language-model training graph with sync_replicas enabled."""
     # NOTE(review): FLAGS is global state; presumably reset by a setUp that
     # is not shown here — confirm.
     FLAGS.sync_replicas = True
     model = graphs.VatxtModel()
     model.language_model_training()
コード例 #7
0
 def testMulticlass(self):
     """Build the classifier graph with num_classes set to 10."""
     # NOTE(review): FLAGS is global state; presumably reset by a setUp that
     # is not shown here — confirm.
     FLAGS.num_classes = 10
     model = graphs.VatxtModel()
     model.classifier_graph()
コード例 #8
0
 def testLanguageModelGraph(self):
     """Build the language-model training graph and run one training step."""
     train_op, _, _ = graphs.VatxtModel().language_model_training()
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         # The graph's inputs are fed by queue runners. Manage their threads
         # with a Coordinator so they are stopped and joined here, instead of
         # being left running (and erroring) after the step completes and
         # the session goes away.
         coord = tf.train.Coordinator()
         threads = tf.train.start_queue_runners(sess=sess, coord=coord)
         try:
             sess.run(train_op)
         finally:
             coord.request_stop()
             coord.join(threads)