def testClassifierGraph(self):
  """Builds the classifier training graph and runs one training step.

  A two-layer RNN is configured so that the pretrained-variable count check
  covers more than a single LSTM layer.
  """
  FLAGS.rnn_num_layers = 2
  model = graphs.VatxtModel()
  train_op, _, _ = model.classifier_training()

  # Pretrained vars: one embedding variable plus two variables per LSTM layer,
  # as encoded by the expected count below.
  self.assertEqual(
      len(model.pretrained_variables), 1 + 2 * FLAGS.rnn_num_layers)

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    # Start the input queue runners under a Coordinator so their threads are
    # stopped and joined once the step finishes, rather than being left
    # running after the test (the original started them with no coordinator).
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
      sess.run(train_op)
    finally:
      coord.request_stop()
      coord.join(threads)
def testATMethods(self):
  """Builds the classifier graph once per adversarial-training method.

  For every method the number of trainable variables must stay the same,
  which checks that the adversarial/perturbed passes reuse the existing
  variables instead of creating new ones.
  """
  at_methods = [None, 'rp', 'at', 'vat', 'atvat']

  # Embedding + LSTM layers + hidden layers + logits layer. Only
  # adv_training_method changes inside the loop, so this is loop-invariant.
  expected_num_vars = 1 + 2 * FLAGS.rnn_num_layers + 2 * (
      FLAGS.cl_num_layers) + 2

  for method in at_methods:
    FLAGS.adv_training_method = method
    # A fresh graph per method keeps the variable counts independent.
    with tf.Graph().as_default():
      graphs.VatxtModel().classifier_graph()

      # Ensure variables have been reused. The message names the failing
      # method, since a bare assertion inside the loop would not say which one.
      self.assertEqual(
          len(tf.trainable_variables()),
          expected_num_vars,
          msg='adv_training_method=%r' % (method,))
def testEvalGraph(self):
  """Checks that the evaluation graph builds and returns exactly two values."""
  model = graphs.VatxtModel()
  # Tuple unpacking fails unless eval_graph() returns a two-element iterable.
  unused_first, unused_second = model.eval_graph()
def testSeqAE(self):
  """Builds the language-model training graph with the seq2seq autoencoder."""
  FLAGS.use_seq2seq_autoencoder = True
  model = graphs.VatxtModel()
  model.language_model_training()
def testCandidateSampling(self):
  """Builds the language-model training graph with 10 candidate samples."""
  FLAGS.num_candidate_samples = 10
  model = graphs.VatxtModel()
  model.language_model_training()
def testSyncReplicas(self):
  """Builds the language-model training graph with sync_replicas enabled."""
  FLAGS.sync_replicas = True
  model = graphs.VatxtModel()
  model.language_model_training()
def testMulticlass(self):
  """Builds the classifier graph configured for 10 output classes."""
  FLAGS.num_classes = 10
  model = graphs.VatxtModel()
  model.classifier_graph()
def testLanguageModelGraph(self):
  """Builds the language-model training graph and runs one training step."""
  train_op, _, _ = graphs.VatxtModel().language_model_training()

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    # Start the input queue runners under a Coordinator so their threads are
    # stopped and joined once the step finishes, rather than being left
    # running after the test (the original started them with no coordinator).
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
      sess.run(train_op)
    finally:
      coord.request_stop()
      coord.join(threads)