def test_sample_prob(self):
    with self.session():
        sampling_rate = 1e-5
        min_count = 1
        vocab_filepath = os.path.join(os.path.dirname(__file__),
                                      'resources', 'wiki.test.vocab')
        w2v = Word2Vec()
        w2v.load_vocab(vocab_filepath, min_count)
        word_count_table = vocab_utils.get_tf_word_count_table(
            w2v._words, w2v._counts)
        test_data_filepath = os.path.join(os.path.dirname(__file__),
                                          'resources', 'data.txt')
        dataset = (tf.data.TextLineDataset(test_data_filepath)
                   .map(tf.strings.strip)
                   .map(lambda x: tf.strings.split([x]).to_sparse()))
        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
        x = iterator.get_next()
        tokens = tf.convert_to_tensor(value=x.values.numpy())
        prob = datasets_utils.sample_prob(
            tokens, sampling_rate, word_count_table, w2v._total_count)
        sample = lambda count: 1 - math.sqrt(
            sampling_rate / (count / w2v._total_count))
        self.assertAllEqual(
            prob,
            tf.constant(
                [sample(y) for y in word_count_table.lookup(tokens).numpy()],
                dtype=tf.float64))
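# A minimal sketch of what datasets_utils.sample_prob could look like, inferred
# from the assertion above: each token's discard probability is
# 1 - sqrt(sampling_rate / freq), with freq = count / total_count (the
# subsampling formula of Mikolov et al.). This is an assumption about the
# implementation, not the project's actual code.
import tensorflow as tf

def sample_prob(tokens, sampling_rate, word_count_table, total_count):
    # int64 counts divided by a Python int yield a float64 tensor, matching
    # the dtype=tf.float64 the test expects
    freq = word_count_table.lookup(tokens) / total_count
    return 1.0 - tf.math.sqrt(sampling_rate / freq)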
def test_get_cbow_train_dataset(self):
    with self.session():
        tf.compat.v1.disable_tensor_equality()
        sampling_rate = 1.
        window_size = 5
        min_count = 50
        batch_size = 2
        num_epochs = 1
        p_num_threads = 1
        shuffling_buffer_size = 1
        vocab_filepath = os.path.join(os.path.dirname(__file__),
                                      'resources', 'wiki.test.vocab')
        w2v = Word2Vec()
        w2v.load_vocab(vocab_filepath, min_count)
        test_data_filepath = os.path.join(os.path.dirname(__file__),
                                          'resources', 'data.txt')
        dataset = datasets_utils.get_w2v_train_dataset(
            test_data_filepath, 'cbow', w2v._words, w2v._counts,
            w2v._total_count, window_size, sampling_rate, batch_size,
            num_epochs, p_num_threads, shuffling_buffer_size)
        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
        features, labels = iterator.get_next()
        self.assertAllEqual(
            features,
            tf.constant([[b'is', b'a', b'that', b'-', b'on',
                          b'_CBOW#_!MASK_', b'_CBOW#_!MASK_', b'_CBOW#_!MASK_',
                          b'_CBOW#_!MASK_', b'_CBOW#_!MASK_'],
                         [b'anarchism', b'a', b'that', b'-', b'on', b'.',
                          b'_CBOW#_!MASK_', b'_CBOW#_!MASK_', b'_CBOW#_!MASK_',
                          b'_CBOW#_!MASK_']]))
        self.assertAllEqual(labels, tf.constant([[b'anarchism'], [b'is']]))
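# The expected features above follow from a simple windowing scheme: for each
# target token, take up to window_size tokens on each side as context, then pad
# the context to 2 * window_size with the '_CBOW#_!MASK_' token. A plain-Python
# sketch of that logic (an illustration, not the pipeline's tf.data code):
CBOW_MASK = '_CBOW#_!MASK_'

def cbow_examples(tokens, window_size):
    for i, target in enumerate(tokens):
        context = (tokens[max(0, i - window_size):i]
                   + tokens[i + 1:i + 1 + window_size])
        context += [CBOW_MASK] * (2 * window_size - len(context))
        yield context, [target]

# With tokens ['anarchism', 'is', 'a', 'that', '-', 'on', '.'] and
# window_size=5, the first two yielded (context, target) pairs are exactly the
# two batched examples asserted above.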
def _train(args):
    logger.info('Training Tensorflow implementation of Word2Vec')
    output_model_dirpath = futils.get_model_dirpath(
        args.datafile, args.outputdir, args.train_mode, args.alpha,
        args.neg, args.window, args.sample, args.epochs, args.min_count,
        args.size, args.batch)
    w2v = Word2Vec()
    if not args.vocab or not os.path.exists(args.vocab):
        if not args.datafile:
            raise Exception(
                'Unspecified data_filepath. You need to specify the data '
                'file from which to build the vocabulary, or specify a '
                'valid vocabulary filepath')
        if args.vocab and not os.path.exists(args.vocab):
            logger.warning('The specified vocabulary filepath does not seem '
                           'to exist: {}'.format(args.vocab))
            logger.warning('Re-building vocabulary from scratch')
        vocab_filepath = futils.get_vocab_filepath(args.datafile,
                                                   output_model_dirpath)
        w2v.build_vocab(args.datafile, vocab_filepath, args.min_count)
    else:
        w2v.load_vocab(args.vocab, args.min_count)
    w2v.train(args.train_mode, args.datafile, output_model_dirpath,
              args.batch, args.size, args.neg, args.alpha, args.window,
              args.epochs, args.sample, args.p_num_threads,
              args.t_num_threads, args.shuffling_buffer_size,
              args.save_summary_steps, args.save_checkpoints_steps,
              args.keep_checkpoint_max, args.log_step_count_steps,
              args.debug, args.debug_port, args.xla)
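# A hypothetical sketch of the CLI wiring that could feed _train. Only the
# attribute names (args.datafile, args.min_count, ...) come from the function
# above; every flag spelling and default below is an assumption, not the
# project's actual parser.
import argparse

def _parse_args():
    parser = argparse.ArgumentParser(description='TF Word2Vec trainer')
    parser.add_argument('--datafile',
                        help='training corpus, one sentence per line')
    parser.add_argument('--vocab',
                        help='optional pre-built vocabulary filepath')
    parser.add_argument('--outputdir', required=True)
    parser.add_argument('--train-mode', dest='train_mode',
                        choices=['cbow', 'skipgram'], default='cbow')
    parser.add_argument('--alpha', type=float, default=0.025)
    parser.add_argument('--neg', type=int, default=5)
    parser.add_argument('--window', type=int, default=5)
    parser.add_argument('--sample', type=float, default=1e-5)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--min-count', dest='min_count', type=int, default=50)
    parser.add_argument('--size', type=int, default=300)
    parser.add_argument('--batch', type=int, default=128)
    # remaining flags (p_num_threads, t_num_threads, shuffling_buffer_size,
    # checkpointing options, debug, xla, ...) omitted from this sketch
    return parser.parse_args()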
def test_sample_tokens(self):
    with self.test_session() as session:
        min_count = 50
        sampling_rate = 1.
        test_data_filepath = os.path.join(os.path.dirname(__file__),
                                          'resources', 'data.txt')
        vocab_filepath = os.path.join(os.path.dirname(__file__),
                                      'resources', 'wiki.test.vocab')
        w2v = Word2Vec()
        w2v.load_vocab(vocab_filepath, min_count)
        word_count_table = vocab_utils.get_tf_word_count_table(
            w2v._words, w2v._counts)
        tf.tables_initializer().run()
        dataset = (tf.data.TextLineDataset(test_data_filepath)
                   .map(tf.strings.strip)
                   .map(lambda x: tf.strings.split([x])))
        iterator = dataset.make_initializable_iterator()
        init_op = iterator.initializer
        x = iterator.get_next()
        session.run(init_op)
        self.assertAllEqual(
            datasets_utils.sample_tokens(x.values, sampling_rate,
                                         word_count_table, w2v._total_count),
            tf.constant([b'anarchism', b'is', b'a', b'that', b'-', b'on',
                         b'.']))
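# A sketch of datasets_utils.sample_tokens consistent with this test: compute
# the keep mask (see filter_tokens_mask, tested below) and apply it with
# tf.boolean_mask. An assumption about the implementation, not the project's
# actual code.
import tensorflow as tf

def sample_tokens(tokens, sampling_rate, word_count_table, total_count):
    mask = filter_tokens_mask(tokens, sampling_rate, word_count_table,
                              total_count)
    return tf.boolean_mask(tokens, mask)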
def test_get_tf_vocab_table(self):
    with self.session():
        min_count = 1
        vocab_filepath = os.path.join(os.path.dirname(__file__),
                                      'resources', 'wiki.test.vocab')
        w2v = Word2Vec()
        w2v.load_vocab(vocab_filepath, min_count)
        vocab = vocab_utils.get_tf_vocab_table(w2v._words)
        self.assertAllEqual(
            vocab.lookup(tf.constant(['anarchism', 'is', 'UKN@!', '1711'])),
            tf.constant([0, 1, len(w2v._words), len(w2v._words) - 1]))
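# A sketch of vocab_utils.get_tf_vocab_table consistent with the assertion
# above: in-vocabulary words map to their index in the frequency-sorted word
# list, and a single OOV bucket maps unknown words ('UKN@!') to len(words).
# Assumed implementation, not the project's actual code.
import tensorflow as tf

def get_tf_vocab_table(words):
    initializer = tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(words),
        values=tf.range(len(words), dtype=tf.int64))
    # num_oov_buckets=1 sends every out-of-vocabulary key to index len(words)
    return tf.lookup.StaticVocabularyTable(initializer, num_oov_buckets=1)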
def test_get_word_count_table(self):
    with self.session():
        min_count = 1
        vocab_filepath = os.path.join(os.path.dirname(__file__),
                                      'resources', 'wiki.test.vocab')
        w2v = Word2Vec()
        w2v.load_vocab(vocab_filepath, min_count)
        word_count = vocab_utils.get_tf_word_count_table(
            w2v._words, w2v._counts)
        self.assertAllEqual(
            word_count.lookup(
                tf.constant(['anarchism', 'is', 'UKN@!', '1711'])),
            tf.constant([112, 283, 0, 1]))
        freq = word_count.lookup(
            tf.constant(['anarchism', 'is', 'UKN@!', '1711'])) \
            / w2v._total_count
        self.assertAllEqual(
            freq,
            tf.constant([112/26084, 283/26084, 0, 1/26084],
                        dtype=tf.float64))
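# A sketch of vocab_utils.get_tf_word_count_table consistent with the test: a
# static string-to-int64 hash table returning each word's corpus count, with a
# fallback of 0 for out-of-vocabulary keys such as 'UKN@!'. Assumed, not the
# project's actual code.
import tensorflow as tf

def get_tf_word_count_table(words, counts):
    initializer = tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(words),
        values=tf.constant(counts, dtype=tf.int64))
    # default_value=0 makes unknown words look like zero-count words
    return tf.lookup.StaticHashTable(initializer, default_value=0)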
def test_get_cbow_train_dataset(self):
    with self.test_session() as session:
        sampling_rate = 1.
        window_size = 5
        min_count = 50
        batch_size = 2
        num_epochs = 1
        p_num_threads = 1
        shuffling_buffer_size = 1
        vocab_filepath = os.path.join(os.path.dirname(__file__),
                                      'resources', 'wiki.test.vocab')
        w2v = Word2Vec()
        w2v.load_vocab(vocab_filepath, min_count)
        test_data_filepath = os.path.join(os.path.dirname(__file__),
                                          'resources', 'data.txt')
        dataset = datasets_utils.get_w2v_train_dataset(
            test_data_filepath, 'cbow', w2v._words, w2v._counts,
            w2v._total_count, window_size, sampling_rate, batch_size,
            num_epochs, p_num_threads, shuffling_buffer_size)
        tf.tables_initializer().run()
        iterator = dataset.make_initializable_iterator()
        init_op = iterator.initializer
        x = iterator.get_next()
        session.run(init_op)
        features, labels = session.run(x)
        self.assertAllEqual(
            features,
            tf.constant([[b'is', b'a', b'that', b'-', b'on',
                          b'_CBOW#_!MASK_', b'_CBOW#_!MASK_', b'_CBOW#_!MASK_',
                          b'_CBOW#_!MASK_', b'_CBOW#_!MASK_'],
                         [b'anarchism', b'a', b'that', b'-', b'on', b'.',
                          b'_CBOW#_!MASK_', b'_CBOW#_!MASK_', b'_CBOW#_!MASK_',
                          b'_CBOW#_!MASK_']]))
        self.assertAllEqual(labels, tf.constant([[b'anarchism'], [b'is']]))
def test_filter_tokens_mask(self):
    with self.session():
        min_count = 50
        sampling_rate = 1.
        test_data_filepath = os.path.join(os.path.dirname(__file__),
                                          'resources', 'data.txt')
        vocab_filepath = os.path.join(os.path.dirname(__file__),
                                      'resources', 'wiki.test.vocab')
        w2v = Word2Vec()
        w2v.load_vocab(vocab_filepath, min_count)
        word_count_table = vocab_utils.get_tf_word_count_table(
            w2v._words, w2v._counts)
        dataset = (tf.data.TextLineDataset(test_data_filepath)
                   .map(tf.strings.strip)
                   .map(lambda x: tf.strings.split([x]).to_sparse()))
        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
        x = iterator.get_next()
        self.assertAllEqual(
            datasets_utils.filter_tokens_mask(
                x.values, sampling_rate, word_count_table, w2v._total_count),
            tf.constant([True, True, True, False, False, True, False, False,
                         True, False, False, False, True, False, False,
                         True]))
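# A sketch of datasets_utils.filter_tokens_mask consistent with the expected
# mask above: out-of-vocabulary tokens (count == 0) are always dropped, while
# in-vocabulary tokens are kept unless a uniform draw falls below the
# sample_prob discard probability. With sampling_rate=1. the discard
# probability is negative for every in-vocabulary word, which is what makes
# the mask deterministic in this test. This is an assumption about the
# implementation, not the project's actual code.
import tensorflow as tf

def filter_tokens_mask(tokens, sampling_rate, word_count_table, total_count):
    counts = word_count_table.lookup(tokens)
    freq = counts / total_count  # float64
    discard_prob = 1.0 - tf.math.sqrt(sampling_rate / freq)
    keep = tf.random.uniform(tf.shape(tokens),
                             dtype=tf.float64) >= discard_prob
    # counts > 0 filters out-of-vocabulary tokens regardless of the draw
    return tf.logical_and(counts > 0, keep)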