import os
from os.path import join

import tensorflow as tf
import tensorflow_datasets as tfds

# FullTokenizer as shipped with Google's BERT reference implementation;
# adjust this import to wherever tokenization.py lives in this project.
from tokenization import FullTokenizer


def construct_flat_datasets(args1, subwords_path):
    global tokenizer_bert, tokenizer_ro, args
    args = args1

    if args.bert:
        tokenizer_bert = FullTokenizer(
            vocab_file=join(args.bert_model_dir, "vocab.vocab"))
        tokenizer_bert.vocab_size = len(tokenizer_bert.vocab)

    samples = get_text_samples(args)

    # Reuse a previously built subword vocabulary when it exists on disk;
    # otherwise build one from the text samples.
    if os.path.isfile(subwords_path + '.subwords'):
        tokenizer_ro = construct_tokenizer(None, subwords_path, args)
    else:
        tokenizer_ro = construct_tokenizer(list(samples), subwords_path, args)

    sample_train = int(args.total_samples * args.train_dev_split)

    if args.records:
        # Record mode: variable-length (ids, segments) pairs plus targets.
        dataset = tf.data.Dataset.from_generator(
            generator_tensors_ids_and_segs,
            ((tf.int64, tf.int64), tf.int64),
            ((tf.TensorShape([None]), tf.TensorShape([None])),
             tf.TensorShape([None])))
        if args.separate:
            train_dataset = dataset
            dev_dataset = tf.data.Dataset.from_generator(
                generator_tensors_ids_and_segs_dev,
                ((tf.int64, tf.int64), tf.int64),
                ((tf.TensorShape([None]), tf.TensorShape([None])),
                 tf.TensorShape([None])))
            return train_dataset, dev_dataset
    else:
        # Fixed-length mode: materialize the generator once to count the
        # examples, then recreate the dataset from the generator.
        nr_samples = len(list(generator_tensors_ids()))
        sample_train = int(args.train_dev_split * nr_samples)
        dataset = tf.data.Dataset.from_generator(
            generator_tensors_ids,
            (tf.int64, tf.int64),
            (tf.TensorShape([2, args.seq_length]),
             tf.TensorShape([args.seq_length])))
        if args.separate:
            dev_dataset = tf.data.Dataset.from_generator(
                generator_tensors_ids_dev,
                (tf.int64, tf.int64),
                (tf.TensorShape([2, args.seq_length]),
                 tf.TensorShape([args.seq_length])))
            return dataset, dev_dataset

    # No separate dev source: carve the dev set out of the training stream.
    train_dataset = dataset.take(sample_train)
    dev_dataset = dataset.skip(sample_train)
    return train_dataset, dev_dataset
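
# A minimal usage sketch, not part of the pipeline itself: the split returned
# above is typically shuffled and batched before training. The subwords path
# and batch size below are hypothetical, and batch() assumes the fixed-length
# (non-records) pipeline; the records generators yield variable-length tensors
# that would need padded_batch() instead.
def _demo_construct_flat_datasets(cli_args):
    train_dataset, dev_dataset = construct_flat_datasets(
        cli_args, join(cli_args.checkpoint, 'tokenizer_ro'))
    train_dataset = train_dataset.shuffle(10000).batch(32)
    dev_dataset = dev_dataset.batch(32)
    return train_dataset, dev_dataset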
def get_tokenizers_ckeckpoint(args1):
    global args
    args = args1

    tokenizer_ro_path = join(args.checkpoint, 'tokenizer_ro')
    tokenizer_ro = tfds.features.text.SubwordTextEncoder.load_from_file(
        tokenizer_ro_path)
    tf.compat.v1.logging.info(
        'restoring ro tokenizer from {}'.format(tokenizer_ro_path))

    tokenizer_bert = None
    if args.bert:
        tokenizer_bert_path = join(args.checkpoint, 'tokenizer_bert.vocab')
        tokenizer_bert = FullTokenizer(vocab_file=tokenizer_bert_path)
        tokenizer_bert.vocab_size = len(tokenizer_bert.vocab)
        tf.compat.v1.logging.info(
            'restoring bert tokenizer from {}'.format(tokenizer_bert_path))

    tf.compat.v1.logging.info('tokenizers restored')
    return tokenizer_ro, tokenizer_bert
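
# A minimal sketch of round-tripping text through the restored tokenizers;
# the sample sentences are placeholders and `cli_args` is assumed to carry
# the same `checkpoint` / `bert` fields the loader above reads.
def _demo_restored_tokenizers(cli_args):
    tokenizer_ro, tokenizer_bert = get_tokenizers_ckeckpoint(cli_args)
    ids = tokenizer_ro.encode('un exemplu de propozitie')
    assert tokenizer_ro.decode(ids) == 'un exemplu de propozitie'
    if tokenizer_bert is not None:
        tokens = tokenizer_bert.tokenize('an example sentence')
        _ = tokenizer_bert.convert_tokens_to_ids(tokens)
    return ids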