def test_IterRefLSTM_pickle():
    """Build a TensorGraph around an IterRefLSTMEmbedding layer and save it.

    The test passes when graph construction and ``save()`` complete without
    raising; no values are asserted.
    """
    # Sizes: test and support batches are both 5 rows of 10 features,
    # and the embedding layer is given a max depth of 5.
    n_feat = 10
    n_test, n_support = 5, 5
    max_depth = 5

    graph = TensorGraph()

    # Input layers are created in the same order as they are wired in below.
    test_input = Feature(shape=(None, n_feat))
    support_input = Feature(shape=(None, n_feat))

    embedding = IterRefLSTMEmbedding(
        n_test,
        n_support,
        n_feat,
        max_depth,
        in_layers=[test_input, support_input],
    )

    # The embedding layer serves as both the graph output and the loss.
    graph.add_output(embedding)
    graph.set_loss(embedding)

    graph.build()
    graph.save()
def test_iter_ref_lstm_embedding(self):
    """Check the output shapes produced by IterRefLSTMEmbedding.

    Random test and support matrices are fed through the layer; each of the
    two outputs must have the same shape as its corresponding input.
    """
    n_feat = 10
    n_support = 11
    n_test = 5
    max_depth = 5

    # Draw the test matrix first, then the support matrix.
    test_array = np.random.rand(n_test, n_feat)
    support_array = np.random.rand(n_support, n_feat)

    with self.session() as sess:
        test_tensor = tf.convert_to_tensor(test_array, dtype=tf.float32)
        support_tensor = tf.convert_to_tensor(support_array, dtype=tf.float32)

        layer = IterRefLSTMEmbedding(n_test, n_support, n_feat, max_depth)
        outputs = layer(test_tensor, support_tensor)

        # Variables must be initialized before the outputs can be evaluated.
        sess.run(tf.global_variables_initializer())

        test_out = outputs[0].eval()
        support_out = outputs[1].eval()

        assert test_out.shape == (n_test, n_feat)
        assert support_out.shape == (n_support, n_feat)