示例#1
0
    def testPretrainedEmbeddingsLoading(self):
        """Loads header-less pretrained embeddings with and without case folding.

        The vocabulary holds two casings of "toto", plus "tata" and "tete".
        Only "toto", "titi" and "tata" have vectors in the embedding file.
        With case-insensitive matching, both "Toto" and "tOTO" are expected to
        receive the "toto" vector.
        """
        embedding_path = os.path.join(self.get_temp_dir(), "embedding.txt")
        vocab_path = os.path.join(self.get_temp_dir(), "vocab.txt")

        embedding_text = u"toto 1 1\n" u"titi 2 2\n" u"tata 3 3\n"
        vocab_text = u"Toto\n" u"tOTO\n" u"tata\n" u"tete\n"
        with io.open(embedding_path, encoding="utf-8", mode="w") as stream:
            stream.write(embedding_text)
        with io.open(vocab_path, encoding="utf-8", mode="w") as stream:
            stream.write(vocab_text)

        # 4 vocabulary entries + 1 OOV bucket -> 5 rows of dimension 2.
        folded = text_inputter.load_pretrained_embeddings(
            embedding_path,
            vocab_path,
            num_oov_buckets=1,
            with_header=False,
            case_insensitive_embeddings=True)
        self.assertAllEqual([5, 2], folded.shape)
        for row, expected in enumerate([[1, 1], [1, 1], [3, 3]]):
            self.assertAllEqual(expected, folded[row])

        # 4 vocabulary entries + 2 OOV buckets -> 6 rows. "tata" matches
        # exactly, so it keeps its vector even without case folding.
        exact = text_inputter.load_pretrained_embeddings(
            embedding_path,
            vocab_path,
            num_oov_buckets=2,
            with_header=False,
            case_insensitive_embeddings=False)
        self.assertAllEqual([6, 2], exact.shape)
        self.assertAllEqual([3, 3], exact[2])
示例#2
0
    def testPretrainedEmbeddingsLoading(self):
        """Checks header-less embedding loading with and without case folding.

        Vocabulary: "Toto", "tOTO", "tata", "tete". Vectors exist for "toto",
        "titi" and "tata". With case-insensitive matching, the first two
        vocabulary rows are expected to share the "toto" vector.
        """
        vocab_tokens = ["Toto", "tOTO", "tata", "tete"]
        vectors = [
            ("toto", [1, 1]),
            ("titi", [2, 2]),
            ("tata", [3, 3]),
        ]
        vocab_file = self._makeTextFile("vocab.txt", vocab_tokens)
        embedding_file = self._makeEmbeddingsFile(vectors)

        # 4 vocabulary entries + 1 OOV bucket -> 5 rows of dimension 2.
        folded = text_inputter.load_pretrained_embeddings(
            embedding_file,
            vocab_file,
            num_oov_buckets=1,
            with_header=False,
            case_insensitive_embeddings=True,
        )
        self.assertAllEqual([5, 2], folded.shape)
        for row, expected in enumerate([[1, 1], [1, 1], [3, 3]]):
            self.assertAllEqual(expected, folded[row])

        # 4 vocabulary entries + 2 OOV buckets -> 6 rows; exact-case "tata"
        # still matches.
        exact = text_inputter.load_pretrained_embeddings(
            embedding_file,
            vocab_file,
            num_oov_buckets=2,
            with_header=False,
            case_insensitive_embeddings=False,
        )
        self.assertAllEqual([6, 2], exact.shape)
        self.assertAllEqual([3, 3], exact[2])
示例#3
0
  def testPretrainedEmbeddingsLoading(self):
    """Loads pretrained embeddings with and without case-insensitive matching.

    The vocabulary holds two casings of "toto", plus "tata" and "tete".
    Only "toto", "titi" and "tata" have vectors in the embedding file.
    """
    import os  # Local import: needed to build the fixture paths below.

    # The fixture paths must be defined before the files are written;
    # without them the test fails with NameError.
    embedding_file = os.path.join(self.get_temp_dir(), "embedding.txt")
    vocab_file = os.path.join(self.get_temp_dir(), "vocab.txt")

    with open(embedding_file, "w") as embedding:
      embedding.write("toto 1 1\n"
                      "titi 2 2\n"
                      "tata 3 3\n")
    with open(vocab_file, "w") as vocab:
      vocab.write("Toto\n"
                  "tOTO\n"
                  "tata\n"
                  "tete\n")

    # NOTE(review): the embedding file has no header line and no
    # `with_header` argument is passed; confirm that the default of
    # load_pretrained_embeddings in this version does not skip line one.
    # 4 vocabulary entries + 1 OOV bucket -> 5 rows of dimension 2.
    embeddings = text_inputter.load_pretrained_embeddings(
        embedding_file,
        vocab_file,
        num_oov_buckets=1,
        case_insensitive_embeddings=True)
    self.assertAllEqual([5, 2], embeddings.shape)
    self.assertAllEqual([1, 1], embeddings[0])
    self.assertAllEqual([1, 1], embeddings[1])
    self.assertAllEqual([3, 3], embeddings[2])

    # 4 vocabulary entries + 2 OOV buckets -> 6 rows; exact-case "tata"
    # still matches.
    embeddings = text_inputter.load_pretrained_embeddings(
        embedding_file,
        vocab_file,
        num_oov_buckets=2,
        case_insensitive_embeddings=False)
    self.assertAllEqual([6, 2], embeddings.shape)
    self.assertAllEqual([3, 3], embeddings[2])
示例#4
0
  def testPretrainedEmbeddingsLoading(self):
    """Loads header-less pretrained embeddings with and without case folding.

    The vocabulary holds two casings of "toto", plus "tata" and "tete".
    Only "toto", "titi" and "tata" have vectors in the embedding file.
    """
    import os  # Local import: needed to build the fixture paths below.

    # The fixture paths must be defined before the files are written;
    # without them the test fails with NameError.
    embedding_file = os.path.join(self.get_temp_dir(), "embedding.txt")
    vocab_file = os.path.join(self.get_temp_dir(), "vocab.txt")

    with open(embedding_file, "w") as embedding:
      embedding.write("toto 1 1\n"
                      "titi 2 2\n"
                      "tata 3 3\n")
    with open(vocab_file, "w") as vocab:
      vocab.write("Toto\n"
                  "tOTO\n"
                  "tata\n"
                  "tete\n")

    # 4 vocabulary entries + 1 OOV bucket -> 5 rows of dimension 2.
    embeddings = text_inputter.load_pretrained_embeddings(
        embedding_file,
        vocab_file,
        num_oov_buckets=1,
        with_header=False,
        case_insensitive_embeddings=True)
    self.assertAllEqual([5, 2], embeddings.shape)
    self.assertAllEqual([1, 1], embeddings[0])
    self.assertAllEqual([1, 1], embeddings[1])
    self.assertAllEqual([3, 3], embeddings[2])

    # 4 vocabulary entries + 2 OOV buckets -> 6 rows; exact-case "tata"
    # still matches.
    embeddings = text_inputter.load_pretrained_embeddings(
        embedding_file,
        vocab_file,
        num_oov_buckets=2,
        with_header=False,
        case_insensitive_embeddings=False)
    self.assertAllEqual([6, 2], embeddings.shape)
    self.assertAllEqual([3, 3], embeddings[2])
示例#5
0
def load_embeddings(embedding_file, vocab_file):
    """Return the "embedding" variable, building it from files if needed.

    First tries ``tf.get_variable("embedding")``. If that raises
    ``ValueError``, the pretrained matrix is loaded from ``embedding_file``
    and ``vocab_file`` (1 OOV bucket, header line expected,
    case-insensitive matching). A non-trainable variable named "embedding"
    is then created, initialized with that matrix cast to float32.

    NOTE(review): a ValueError presumably also occurs when the variable
    exists but variable reuse is disabled; confirm callers set reuse.
    """
    try:
        existing = tf.get_variable("embedding")
    except ValueError:
        matrix = load_pretrained_embeddings(
            embedding_file,
            vocab_file,
            num_oov_buckets=1,
            with_header=True,
            case_insensitive_embeddings=True,
        )
        initial_value = tf.constant(matrix.astype(np.float32))
        return tf.get_variable(
            "embedding",
            shape=None,
            trainable=False,
            initializer=initial_value,
        )
    return existing
示例#6
0
def load_embeddings(embedding_file, vocab_file):
  """Return the "embedding" variable, building it from files if needed.

  First tries ``tf.get_variable("embedding")``. If that raises
  ``ValueError``, the pretrained matrix is loaded from ``embedding_file``
  and ``vocab_file`` (1 OOV bucket, header line expected, case-insensitive
  matching). A non-trainable variable named "embedding" is then created,
  initialized with that matrix cast to float32.

  NOTE(review): a ValueError presumably also occurs when the variable
  exists but variable reuse is disabled; confirm callers set reuse.
  """
  try:
    existing = tf.get_variable("embedding")
  except ValueError:
    matrix = load_pretrained_embeddings(
        embedding_file,
        vocab_file,
        num_oov_buckets=1,
        with_header=True,
        case_insensitive_embeddings=True)
    initial_value = tf.constant(matrix.astype(np.float32))
    return tf.get_variable(
        "embedding",
        shape=None,
        trainable=False,
        initializer=initial_value)
  return existing
示例#7
0
    def testPretrainedEmbeddingsWithHeaderLoading(self):
        """Loads embeddings from a file whose first line is a "3 2" header.

        The header is followed by three 2-dimensional vectors. No
        ``with_header`` argument is passed, so the function's default header
        handling is what this test exercises.
        """
        embedding_path = os.path.join(self.get_temp_dir(), "embedding.txt")
        vocab_path = os.path.join(self.get_temp_dir(), "vocab.txt")

        with open(embedding_path, "w") as handle:
            handle.write("3 2\n" "toto 1 1\n" "titi 2 2\n" "tata 3 3\n")
        with open(vocab_path, "w") as handle:
            handle.write("Toto\n" "tOTO\n" "tata\n" "tete\n")

        # 4 vocabulary entries + 1 OOV bucket -> 5 rows of dimension 2.
        loaded = text_inputter.load_pretrained_embeddings(
            embedding_path,
            vocab_path,
            num_oov_buckets=1,
            case_insensitive_embeddings=True)
        self.assertAllEqual([5, 2], loaded.shape)
        for row, expected in enumerate([[1, 1], [1, 1], [3, 3]]):
            self.assertAllEqual(expected, loaded[row])