def test_sample_prob(self):
     with self.session() as session:
         sampling_rate = 1e-5
         min_count = 1
         vocab_filepath = os.path.join(os.path.dirname(__file__),
                                       'resources', 'wiki.test.vocab')
         w2v = Word2Vec()
         w2v.load_vocab(vocab_filepath, min_count)
         word_count_table = vocab_utils.get_tf_word_count_table(
             w2v._words, w2v._counts)
         test_data_filepath = os.path.join(os.path.dirname(__file__),
                                           'resources', 'data.txt')
         dataset = (tf.data.TextLineDataset(test_data_filepath)
                    .map(tf.strings.strip)
                    .map(lambda x: tf.strings.split([x]).to_sparse()))
         iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
         x = iterator.get_next()
          tokens = x.values  # already an eager tensor; no numpy round-trip needed
         prob = datasets_utils.sample_prob(
             tokens, sampling_rate, word_count_table, w2v._total_count)
         sample = lambda x: 1 - math.sqrt(sampling_rate / (x / w2v._total_count))
         self.assertAllEqual(
             prob, tf.constant(
                 [sample(y) for y in word_count_table.lookup(tokens).numpy()],
                 dtype=tf.float64))
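The `sample` lambda above is the word2vec subsampling heuristic from Mikolov et al. (2013): a token with corpus frequency f(w) is discarded with probability 1 - sqrt(t / f(w)), where t is the sampling rate. A minimal plain-Python sketch (the `discard_prob` name is illustrative, not from the library):

import math

def discard_prob(count, total_count, sampling_rate=1e-5):
    # Word2vec subsampling: probability of dropping one occurrence of a
    # token whose corpus count is `count` out of `total_count` tokens.
    freq = count / total_count
    return 1.0 - math.sqrt(sampling_rate / freq)

# For the 'anarchism' entry of the test vocab (count 112 of 26084 tokens,
# see test_get_word_count_table below), this gives roughly 0.95.
print(discard_prob(112, 26084))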
 def test_get_cbow_train_dataset(self):
     with self.session() as session:
         tf.compat.v1.disable_tensor_equality()
         # tf.compat.v1.disable_eager_execution()
         sampling_rate = 1.
         window_size = 5
         min_count = 50
         batch_size = 2
         num_epochs = 1
         p_num_threads = 1
         shuffling_buffer_size = 1
         vocab_filepath = os.path.join(os.path.dirname(__file__),
                                       'resources', 'wiki.test.vocab')
         w2v = Word2Vec()
         w2v.load_vocab(vocab_filepath, min_count)
         test_data_filepath = os.path.join(os.path.dirname(__file__),
                                           'resources', 'data.txt')
         dataset = datasets_utils.get_w2v_train_dataset(
             test_data_filepath, 'cbow', w2v._words, w2v._counts,
             w2v._total_count, window_size, sampling_rate, batch_size,
             num_epochs, p_num_threads, shuffling_buffer_size)
         iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
         features, labels = iterator.get_next()
         self.assertAllEqual(
             features, tf.constant([[b'is', b'a', b'that', b'-', b'on',
                                     b'_CBOW#_!MASK_', b'_CBOW#_!MASK_',
                                     b'_CBOW#_!MASK_', b'_CBOW#_!MASK_',
                                     b'_CBOW#_!MASK_'],
                                    [b'anarchism', b'a', b'that', b'-',
                                     b'on', b'.', b'_CBOW#_!MASK_',
                                     b'_CBOW#_!MASK_', b'_CBOW#_!MASK_',
                                     b'_CBOW#_!MASK_']]))
         self.assertAllEqual(labels, tf.constant([[b'anarchism'], [b'is']]))
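The expected `features` rows make the windowing explicit: each target word gets at most 2 * window_size context tokens, and shorter contexts are right-padded with the `_CBOW#_!MASK_` token to a fixed width. A hedged plain-Python sketch of that padding (the `cbow_context` helper is illustrative):

def cbow_context(tokens, idx, window_size, mask='_CBOW#_!MASK_'):
    # Gather up to window_size tokens on each side of tokens[idx], then
    # right-pad with the mask token to a fixed width of 2 * window_size.
    ctx = (tokens[max(0, idx - window_size):idx]
           + tokens[idx + 1:idx + 1 + window_size])
    return ctx + [mask] * (2 * window_size - len(ctx))

sent = ['anarchism', 'is', 'a', 'that', '-', 'on', '.']
print(cbow_context(sent, 0, 5))  # first feature row above (target 'anarchism')
print(cbow_context(sent, 1, 5))  # second feature row above (target 'is')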
Example #3
def _train(args):
    logger.info('Training Tensorflow implementation of Word2Vec')
    output_model_dirpath = futils.get_model_dirpath(
        args.datafile, args.outputdir, args.train_mode, args.alpha, args.neg,
        args.window, args.sample, args.epochs, args.min_count, args.size,
        args.batch)
    w2v = Word2Vec()
    if not args.vocab or not os.path.exists(args.vocab):
        if not args.datafile:
            raise Exception(
                'Unspecified data_filepath. You need to specify the data '
                'file from which to build the vocabulary, or specify a '
                'valid vocabulary filepath')
        if args.vocab and not os.path.exists(args.vocab):
            logger.warning('The specified vocabulary filepath does not seem '
                           'to exist: {}'.format(args.vocab))
            logger.warning('Re-building vocabulary from scratch')
        vocab_filepath = futils.get_vocab_filepath(args.datafile,
                                                   output_model_dirpath)
        w2v.build_vocab(args.datafile, vocab_filepath, args.min_count)
    else:
        w2v.load_vocab(args.vocab, args.min_count)
    w2v.train(args.train_mode, args.datafile, output_model_dirpath, args.batch,
              args.size, args.neg, args.alpha, args.window, args.epochs,
              args.sample, args.p_num_threads, args.t_num_threads,
              args.shuffling_buffer_size, args.save_summary_steps,
              args.save_checkpoints_steps, args.keep_checkpoint_max,
              args.log_step_count_steps, args.debug, args.debug_port, args.xla)
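`_train` consumes a plain argparse namespace; only the attribute names (`args.datafile`, `args.vocab`, `args.min_count`, and so on) are fixed by the body above. A hedged sketch of the wiring, with illustrative flag spellings and defaults; the 'skipgram' choice is an assumption, since only 'cbow' appears in these tests:

import argparse

def _parse_args():
    parser = argparse.ArgumentParser(description='Train TF word2vec')
    parser.add_argument('--datafile', help='training corpus, one sentence per line')
    parser.add_argument('--outputdir', help='directory for the trained model')
    parser.add_argument('--vocab', help='optional pre-built vocabulary filepath')
    parser.add_argument('--train-mode', dest='train_mode',
                        choices=['cbow', 'skipgram'])
    parser.add_argument('--min-count', dest='min_count', type=int, default=5)
    parser.add_argument('--sample', type=float, default=1e-5,
                        help='subsampling rate')
    # alpha, neg, window, epochs, size, batch, the thread counts and the
    # checkpointing/debug knobs used by _train follow the same pattern.
    return parser.parse_args()

if __name__ == '__main__':
    _train(_parse_args())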
Example #4
 def test_sample_tokens(self):
     with self.test_session() as session:
         min_count = 50
         sampling_rate = 1.
         test_data_filepath = os.path.join(os.path.dirname(__file__),
                                           'resources', 'data.txt')
         vocab_filepath = os.path.join(os.path.dirname(__file__),
                                       'resources', 'wiki.test.vocab')
         w2v = Word2Vec()
         w2v.load_vocab(vocab_filepath, min_count)
         word_count_table = vocab_utils.get_tf_word_count_table(
             w2v._words, w2v._counts)
         tf.tables_initializer().run()
          dataset = (tf.data.TextLineDataset(test_data_filepath)
                     .map(tf.strings.strip)
                     .map(lambda x: tf.strings.split([x])))
         iterator = dataset.make_initializable_iterator()
         init_op = iterator.initializer
         x = iterator.get_next()
         session.run(init_op)
         self.assertAllEqual(
             datasets_utils.sample_tokens(x.values, sampling_rate,
                                          word_count_table,
                                          w2v._total_count),
             tf.constant(
                 [b'anarchism', b'is', b'a', b'that', b'-', b'on', b'.']))
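`sample_tokens` plausibly composes the boolean mask from `datasets_utils.filter_tokens_mask` (exercised in the last test below) with `tf.boolean_mask`. A hedged sketch, not the library's actual implementation:

import tensorflow as tf

def sample_tokens(tokens, sampling_rate, word_count_table, total_count):
    # Keep only the tokens selected by the subsampling mask.
    mask = datasets_utils.filter_tokens_mask(
        tokens, sampling_rate, word_count_table, total_count)
    return tf.boolean_mask(tokens, mask)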
Example #5
 def test_get_tf_vocab_table(self):
     with self.session():
         min_count = 1
         vocab_filepath = os.path.join(os.path.dirname(__file__),
                                       'resources', 'wiki.test.vocab')
         w2v = Word2Vec()
         w2v.load_vocab(vocab_filepath, min_count)
         vocab = vocab_utils.get_tf_vocab_table(w2v._words)
         self.assertAllEqual(
             vocab.lookup(tf.constant(['anarchism', 'is', 'UKN@!',
                                       '1711'])),
             tf.constant([0, 1, len(w2v._words),
                          len(w2v._words)-1]))
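The assertions pin down the table's contract: known words map to their vocabulary index, and anything unknown maps to `len(words)`. A hedged sketch with `tf.lookup.StaticHashTable`:

import tensorflow as tf

def get_tf_vocab_table(words):
    # Known words -> their index; out-of-vocabulary words -> len(words),
    # matching the 'UKN@!' lookup asserted above.
    return tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(words),
            values=tf.range(len(words), dtype=tf.int64)),
        default_value=len(words))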
Example #6
 def test_get_word_count_table(self):
     with self.session():
         min_count = 1
         vocab_filepath = os.path.join(os.path.dirname(__file__),
                                       'resources', 'wiki.test.vocab')
         w2v = Word2Vec()
         w2v.load_vocab(vocab_filepath, min_count)
         word_count = vocab_utils.get_tf_word_count_table(
             w2v._words, w2v._counts)
         self.assertAllEqual(
             word_count.lookup(tf.constant(['anarchism', 'is', 'UKN@!',
                                            '1711'])),
             tf.constant([112, 283, 0, 1]))
         freq = word_count.lookup(
             tf.constant(['anarchism', 'is', 'UKN@!', '1711'])) / w2v._total_count
         self.assertAllEqual(
             freq, tf.constant([112/26084, 283/26084, 0, 1/26084],
                               dtype=tf.float64))
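The count table follows the same pattern with a default count of 0 for unknown words, which is what keeps the `freq` computation above well defined. A hedged sketch:

import tensorflow as tf

def get_tf_word_count_table(words, counts):
    # Known words -> corpus count; unknown words -> 0, matching the
    # 'UKN@!' lookup asserted above.
    return tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(words),
            values=tf.constant(counts, dtype=tf.int64)),
        default_value=0)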
Example #7
 def test_get_cbow_train_dataset(self):
     with self.test_session() as session:
         sampling_rate = 1.
         window_size = 5
         min_count = 50
         batch_size = 2
         num_epochs = 1
         p_num_threads = 1
         shuffling_buffer_size = 1
         vocab_filepath = os.path.join(os.path.dirname(__file__),
                                       'resources', 'wiki.test.vocab')
         w2v = Word2Vec()
         w2v.load_vocab(vocab_filepath, min_count)
         test_data_filepath = os.path.join(os.path.dirname(__file__),
                                           'resources', 'data.txt')
         dataset = datasets_utils.get_w2v_train_dataset(
             test_data_filepath, 'cbow', w2v._words, w2v._counts,
             w2v._total_count, window_size, sampling_rate, batch_size,
             num_epochs, p_num_threads, shuffling_buffer_size)
         tf.tables_initializer().run()
         iterator = dataset.make_initializable_iterator()
         init_op = iterator.initializer
         x = iterator.get_next()
         session.run(init_op)
         features, labels = session.run(x)
          self.assertAllEqual(
              features, tf.constant([[b'is', b'a', b'that', b'-', b'on',
                                      b'_CBOW#_!MASK_', b'_CBOW#_!MASK_',
                                      b'_CBOW#_!MASK_', b'_CBOW#_!MASK_',
                                      b'_CBOW#_!MASK_'],
                                     [b'anarchism', b'a', b'that', b'-',
                                      b'on', b'.', b'_CBOW#_!MASK_',
                                      b'_CBOW#_!MASK_', b'_CBOW#_!MASK_',
                                      b'_CBOW#_!MASK_']]))
         self.assertAllEqual(labels, tf.constant([[b'anarchism'], [b'is']]))
 def test_filter_tokens_mask(self):
     with self.session() as session:
         min_count = 50
         sampling_rate = 1.
         test_data_filepath = os.path.join(os.path.dirname(__file__),
                                           'resources', 'data.txt')
         vocab_filepath = os.path.join(os.path.dirname(__file__),
                                       'resources', 'wiki.test.vocab')
         w2v = Word2Vec()
         w2v.load_vocab(vocab_filepath, min_count)
         word_count_table = vocab_utils.get_tf_word_count_table(
             w2v._words, w2v._counts)
         dataset = (tf.data.TextLineDataset(test_data_filepath)
                    .map(tf.strings.strip)
                    .map(lambda x: tf.strings.split([x]).to_sparse()))
         iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
         x = iterator.get_next()
         self.assertAllEqual(datasets_utils.filter_tokens_mask(
             x.values, sampling_rate, word_count_table, w2v._total_count),
                             tf.constant(
                                 [True, True, True, False, False, True,
                                  False, False, True, False, False, False,
                                  True, False, False, True]))
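The expected mask is deterministic here because sampling_rate=1. makes the discard probability non-positive for every in-vocabulary word, so only out-of-vocabulary (zero-count) tokens are filtered. A hedged sketch consistent with that behaviour, not the library's actual implementation:

import tensorflow as tf

def filter_tokens_mask(tokens, sampling_rate, word_count_table, total_count):
    # True = keep. In-vocabulary tokens survive a uniform draw against the
    # word2vec discard probability; zero-count (out-of-vocabulary) tokens
    # are always dropped, matching the expected mask above.
    counts = tf.cast(word_count_table.lookup(tokens), tf.float64)
    freqs = counts / total_count
    discard_prob = 1.0 - tf.sqrt(sampling_rate / freqs)
    draws = tf.random.uniform(tf.shape(tokens), dtype=tf.float64)
    return tf.logical_and(counts > 0, draws >= discard_prob)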