示例#1
0
    def test_load(self):
        """Check that a saved index is unchanged after being loaded back.

        Builds a ``TrieSignIndex`` over 1000 string signs, saves it to
        ``index.hdf5`` in the directory of this test module, loads it with
        ``TrieSignIndex.load`` and verifies that:

        * both indexes report the same length;
        * every sign is contained in both indexes;
        * every sign maps to the same id in both indexes;
        * ``get_ri(sign).to_vector()`` yields equal arrays for both indexes.

        The file is removed afterwards whether or not the checks pass.
        """
        dim = 100
        act = 10
        gen = Generator(dim, act)

        signs1 = [str(i) for i in range(1000)]
        index1 = TrieSignIndex(gen, vocabulary=signs1)

        directory = os.path.dirname(os.path.abspath(__file__))
        index_file = os.path.join(directory, "index.hdf5")

        # The test expects a clean slate: no leftover file from an earlier run.
        self.assertFalse(os.path.exists(index_file))
        try:
            index1.save(index_file)
            self.assertTrue(os.path.exists(index_file))

            index2 = TrieSignIndex.load(index_file)
            self.assertEqual(len(index2), len(index1))

            for sign in signs1:
                self.assertTrue(index1.contains(sign))
                self.assertTrue(index2.contains(sign))

                id1 = index1.get_id(sign)
                id2 = index2.get_id(sign)
                self.assertEqual(id1, id2)

                ri1 = index1.get_ri(sign).to_vector()
                ri2 = index2.get_ri(sign).to_vector()
                np.testing.assert_array_equal(ri1, ri2)
        finally:
            # Failures propagate on their own; only the cleanup is needed here.
            if os.path.exists(index_file):
                os.remove(index_file)
        self.assertFalse(os.path.exists(index_file))
示例#2
0
    def test_save(self):
        """Check that ``save`` writes an HDF5 file with the expected contents.

        Builds a ``TrieSignIndex`` over 10 string signs, saves it to
        ``index.hdf5`` in the directory of this test module, then opens the
        file read-only and verifies that the ``"signs"`` dataset has one entry
        per sign. The first row of the ``"ri"`` dataset and its ``"k"``,
        ``"s"`` and ``"state"`` attributes are printed for manual inspection.

        The file is removed afterwards whether or not the checks pass.
        """
        dim = 100
        act = 10
        gen = Generator(dim, act)

        signs = [str(i) for i in range(10)]
        sign_index = TrieSignIndex(gen, vocabulary=signs)

        directory = os.path.dirname(os.path.abspath(__file__))
        output_file = os.path.join(directory, "index.hdf5")

        # The test expects a clean slate: no leftover file from an earlier run.
        self.assertFalse(os.path.exists(output_file))
        try:
            sign_index.save(output_file)
            self.assertTrue(os.path.exists(output_file))

            # The context manager closes the handle even if a check below
            # fails, so the cleanup never has to delete a file that is still
            # open.
            with h5py.File(output_file, 'r') as h5file:
                h5signs = h5file["signs"]
                h5ri = h5file["ri"]

                self.assertEqual(len(h5signs), len(signs))

                print(h5ri[0])
                print(h5ri.attrs["k"])
                print(h5ri.attrs["s"])
                # tobytes() is the supported spelling of the deprecated
                # tostring(), which was removed in NumPy 2.0.
                print(h5ri.attrs["state"].tobytes())
        finally:
            # Failures propagate on their own; only the cleanup is needed here.
            if os.path.exists(output_file):
                os.remove(output_file)
        self.assertFalse(os.path.exists(output_file))
示例#3
0
        for e in range(1):
            tf_session.run(train_step, {
                model.input(): x,
                pos_labels(): yp,
                neg_labels(): yn
            })

        x_samples = []
        c_samples = []

    corpus_hdf5.close()
    vocab_hdf5.close()

    print("saving model")
    # save random indexes and model
    sign_index.save(index_file)
    model.save_model(tf_session, model_filename=model_file)

    print("done")

    tf_session.close()

# ======================================================================================
# Process Interrupted
# ======================================================================================
except (KeyboardInterrupt, SystemExit):
    # TODO store the model current state
    # and the state of the corpus iteration
    print("\nProcess interrupted, closing corpus and saving progress...", file=sys.stderr)
    corpus_hdf5.close()
    vocab_hdf5.close()
示例#4
0
        # nrp train on last batch
        if len(x_samples) > 0:
            sess.run(train_step, {
                model.input(): x_samples,
                labels(): y_samples,
            })

            x_samples = []
            y_samples = []

    corpus_hdf5.close()
    vocab_hdf5.close()

    print("saving model")
    # save random indexes and model
    index.save(index_file)
    model.save_model(sess,
                     model_filename=model_file,
                     embeddings_name=model_suffix)

    print("done")

    sess.close()

# ======================================================================================
# Process Interrupted
# ======================================================================================
except (KeyboardInterrupt, SystemExit):
    # TODO save current model progress
    print("\nProcess interrupted, closing corpus and saving progress...",
          file=sys.stderr)